Route(API) lifecycle management for ASGI Python web frameworks — maintenance mode, environment gating, deprecation, rate limiting, admin panels, and more. No restarts required.
+
Route(API) lifecycle management for ASGI Python web frameworks — maintenance mode, environment gating, deprecation, rate limiting, feature flags, admin panels, and more. No restarts required.
@@ -33,6 +33,7 @@ These features are framework-agnostic and available to any adapter.
| ⏰ **Scheduled windows** | `asyncio`-native scheduler — maintenance windows activate and deactivate automatically |
| 🔔 **Webhooks** | Fire HTTP POST on every state change — built-in Slack formatter and custom formatters supported |
| 🚦 **Rate limiting** | Per-IP, per-user, per-API-key, or global counters — tiered limits, burst allowance, runtime mutation |
+| 🚩 **Feature flags** | Boolean, string, integer, float, and JSON flags — targeting rules, user segments, percentage rollouts, prerequisites, and a live evaluation stream. Built on the [OpenFeature](https://openfeature.dev/) standard |
| 🏗️ **Shield Server** | Centralised control plane for multi-service architectures — SDK clients sync state via SSE with zero per-request latency |
| 🌐 **Multi-service CLI** | `SHIELD_SERVICE` env var scopes every command; `shield services` lists connected services |
@@ -193,6 +194,70 @@ Requires `api-shield[rate-limit]`. Powered by [limits](https://limits.readthedoc
---
+## Feature flags
+
+api-shield ships a full feature flag system built on the [OpenFeature](https://openfeature.dev/) standard. All five flag types, multi-condition targeting rules, user segments, percentage rollouts, and a live evaluation stream — managed from the dashboard or CLI with no code changes.
+
+```python
+from shield.core.feature_flags.models import (
+ FeatureFlag, FlagType, FlagVariation, RolloutVariation,
+ TargetingRule, RuleClause, Operator, EvaluationContext,
+)
+
+engine.use_openfeature()
+
+# Define a boolean flag with a 20% rollout and individual targeting
+await engine.save_flag(
+ FeatureFlag(
+ key="new-checkout",
+ name="New Checkout Flow",
+ type=FlagType.BOOLEAN,
+ variations=[
+ FlagVariation(name="on", value=True),
+ FlagVariation(name="off", value=False),
+ ],
+ off_variation="off",
+ fallthrough=[
+ RolloutVariation(variation="on", weight=20_000), # 20%
+ RolloutVariation(variation="off", weight=80_000), # 80%
+ ],
+ targets={"on": ["beta_tester_1"]}, # individual targeting
+ rules=[
+ TargetingRule(
+ description="Enterprise users always get the new flow",
+ clauses=[RuleClause(attribute="plan", operator=Operator.IS, values=["enterprise"])],
+ variation="on",
+ )
+ ],
+ )
+)
+
+# Evaluate in an async route handler
+ctx = EvaluationContext(key=user_id, attributes={"plan": user.plan})
+enabled = await engine.flag_client.get_boolean_value("new-checkout", False, ctx)
+
+# Evaluate in a sync def handler (thread-safe)
+enabled = engine.sync.flag_client.get_boolean_value("new-checkout", False, {"targeting_key": user_id})
+```
+
+Manage flags and segments from the CLI:
+
+```bash
+shield flags list
+shield flags eval new-checkout --user user_123
+shield flags disable new-checkout # kill-switch
+shield flags enable new-checkout
+shield flags stream # live evaluation events
+
+shield segments create beta_users --name "Beta Users"
+shield segments include beta_users --context-key user_123,user_456
+shield segments add-rule beta_users --attribute plan --operator in --values pro,enterprise
+```
+
+Requires `api-shield[flags]`.
+
+---
+
## Framework support
api-shield is built on the **ASGI** standard. The core (`shield.core`) is completely framework-agnostic and has zero framework imports. Any ASGI framework can be supported — either via a Starlette `BaseHTTPMiddleware` (for Starlette-based frameworks) or a raw ASGI callable for frameworks like Quart and Django that implement the ASGI spec independently.
@@ -265,6 +330,7 @@ Full documentation at **[attakay78.github.io/api-shield](https://attakay78.githu
| [Tutorial](https://attakay78.github.io/api-shield/tutorial/installation/) | Get started in 5 minutes |
| [Decorators reference](https://attakay78.github.io/api-shield/reference/decorators/) | All decorator options |
| [Rate limiting](https://attakay78.github.io/api-shield/tutorial/rate-limiting/) | Per-IP, per-user, tiered limits |
+| [Feature flags](https://attakay78.github.io/api-shield/tutorial/feature-flags/) | Targeting rules, segments, rollouts, live events |
| [ShieldEngine reference](https://attakay78.github.io/api-shield/reference/engine/) | Programmatic control |
| [Backends](https://attakay78.github.io/api-shield/tutorial/backends/) | Memory, File, Redis, Shield Server, custom |
| [Admin dashboard](https://attakay78.github.io/api-shield/tutorial/admin-dashboard/) | Mounting ShieldAdmin |
diff --git a/docs/changelog.md b/docs/changelog.md
index 14d7da1..43a6a03 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
+- **Feature flags** (`api-shield[flags]`): a full feature flag system built on the [OpenFeature](https://openfeature.dev/) standard, supporting boolean, string, integer, float, and JSON flag types with multi-condition targeting rules, reusable user segments (explicit included/excluded lists plus attribute-based rules), percentage rollouts, prerequisite flags, individual user targeting, and a live SSE evaluation stream. Flags and segments are manageable from the admin dashboard (`/shield/flags`, `/shield/segments`) and the CLI (`shield flags *`, `shield segments *`) — including a new `shield segments add-rule` command and an "Add Rule" panel in the Edit Segment modal that lets operators add attribute-based targeting rules without touching code or the REST API directly.
+
- **`SHIELD_SERVICE` env var fallback on all `--service` CLI options**: `shield status`, `shield enable`, `shield disable`, `shield maintenance`, and `shield schedule` all read `SHIELD_SERVICE` automatically — set it once with `export SHIELD_SERVICE=payments-service` and every command scopes itself to that service without repeating `--service`. An explicit `--service` flag always wins.
- **`shield current-service` command**: shows the active service context from the `SHIELD_SERVICE` environment variable, or a hint to set it when the variable is absent.
- **`shield services` command**: lists all distinct service names registered with the Shield Server, so you can discover which services are connected before switching context.
diff --git a/docs/index.md b/docs/index.md
index ddea47e..ac16059 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -104,6 +104,7 @@ These features are framework-agnostic and available to every adapter.
| ⏰ **Scheduled windows** | `asyncio`-native scheduler that activates and deactivates maintenance windows automatically |
| 🔔 **Webhooks** | Fire HTTP POST on every state change, with a built-in Slack formatter and support for custom formatters |
| 🚦 **Rate limiting** | Per-IP, per-user, per-API-key, or global counters with tiered limits, burst allowance, and runtime policy mutation |
+| 🚩 **Feature flags** | Boolean, string, integer, float, and JSON flags with targeting rules, user segments, percentage rollouts, prerequisites, and a live evaluation stream — built on the OpenFeature standard |
### Framework adapters
@@ -162,7 +163,9 @@ api-shield is an **ASGI-native** library. The core (`shield.core`) is framework-
- [**Tutorial: Installation**](tutorial/installation.md): get up and running in seconds
- [**Tutorial: First Decorator**](tutorial/first-decorator.md): put your first route in maintenance mode
- [**Tutorial: Rate Limiting**](tutorial/rate-limiting.md): per-IP, per-user, tiered limits, and more
+- [**Tutorial: Feature Flags**](tutorial/feature-flags.md): targeting rules, segments, rollouts, and live events
- [**Reference: Decorators**](reference/decorators.md): full decorator API
- [**Reference: Rate Limiting**](reference/rate-limiting.md): `@rate_limit` parameters, models, and CLI commands
- [**Reference: ShieldEngine**](reference/engine.md): programmatic control
+- [**Reference: Feature Flags**](reference/feature-flags.md): full flag/segment API, models, and CLI commands
- [**Reference: CLI**](reference/cli.md): all CLI commands
diff --git a/docs/reference/feature-flags.md b/docs/reference/feature-flags.md
new file mode 100644
index 0000000..2883852
--- /dev/null
+++ b/docs/reference/feature-flags.md
@@ -0,0 +1,443 @@
+# Feature Flags Reference
+
+API reference for the feature flag system.
+
+!!! note "Optional dependency"
+ ```bash
+ uv add "api-shield[flags]"
+ ```
+
+---
+
+## Engine methods
+
+### `engine.use_openfeature()`
+
+Activate the feature flag subsystem. Call once before any flag evaluation or flag/segment CRUD.
+
+```python
+engine = make_engine()
+engine.use_openfeature()
+```
+
+---
+
+### `engine.flag_client`
+
+OpenFeature-compatible async flag client. Available after `use_openfeature()`.
+
+```python
+value = await engine.flag_client.get_boolean_value(flag_key, default, context)
+value = await engine.flag_client.get_string_value(flag_key, default, context)
+value = await engine.flag_client.get_integer_value(flag_key, default, context)
+value = await engine.flag_client.get_float_value(flag_key, default, context)
+value = await engine.flag_client.get_object_value(flag_key, default, context)
+```
+
+| Parameter | Type | Description |
+|---|---|---|
+| `flag_key` | `str` | The flag's unique key |
+| `default` | `Any` | Returned when the flag is not found or an error occurs |
+| `context` | `EvaluationContext` | Per-request context for targeting |
+
+---
+
+### `engine.sync.flag_client`
+
+Thread-safe synchronous version for `def` (non-async) route handlers.
+
+```python
+enabled = engine.sync.flag_client.get_boolean_value("my-flag", False, ctx)
+```
+
+Accepts `EvaluationContext` objects or plain dicts (`{"targeting_key": user_id, ...}`).
+
+---
+
+### `await engine.save_flag(flag)`
+
+Create or replace a flag.
+
+```python
+await engine.save_flag(FeatureFlag(key="my-flag", ...))
+```
+
+---
+
+### `await engine.get_flag(key)`
+
+Return the `FeatureFlag` for `key`, or `None` if not found.
+
+---
+
+### `await engine.list_flags()`
+
+Return all flags as a list.
+
+---
+
+### `await engine.delete_flag(key)`
+
+Delete a flag.
+
+---
+
+### `await engine.save_segment(segment)`
+
+Create or replace a segment.
+
+---
+
+### `await engine.get_segment(key)`
+
+Return the `Segment` for `key`, or `None`.
+
+---
+
+### `await engine.list_segments()`
+
+Return all segments as a list.
+
+---
+
+### `await engine.delete_segment(key)`
+
+Delete a segment.
+
+---
+
+## Models
+
+### `FeatureFlag`
+
+Definition of a feature flag.
+
+```python
+class FeatureFlag(BaseModel):
+ key: str
+ name: str
+ description: str = ""
+ type: FlagType
+ tags: list[str] = []
+
+ variations: list[FlagVariation]
+ off_variation: str
+ fallthrough: str | list[RolloutVariation]
+
+ enabled: bool = True
+ prerequisites: list[Prerequisite] = []
+ targets: dict[str, list[str]] = {}
+ rules: list[TargetingRule] = []
+ scheduled_changes: list[ScheduledChange] = []
+
+ status: FlagStatus = FlagStatus.ACTIVE
+ temporary: bool = True
+ maintainer: str | None = None
+ created_at: datetime
+ updated_at: datetime
+ created_by: str = "system"
+```
+
+| Field | Description |
+|---|---|
+| `key` | Unique identifier used in code: `get_boolean_value("my-flag", ...)` |
+| `name` | Human-readable display name |
+| `type` | `FlagType.BOOLEAN`, `STRING`, `INTEGER`, `FLOAT`, or `JSON` |
+| `variations` | All possible values; must contain at least two |
+| `off_variation` | Variation served when `enabled=False` |
+| `fallthrough` | Default when no rule matches: a variation name (`str`) or a percentage rollout (`list[RolloutVariation]`) |
+| `enabled` | Kill-switch. `False` means all requests get `off_variation` |
+| `prerequisites` | Flags that must pass before this flag's rules run |
+| `targets` | Individual targeting: `{"on": ["user_1", "user_2"]}` |
+| `rules` | Targeting rules evaluated top-to-bottom; first match wins |
+
+---
+
+### `FlagType`
+
+```python
+class FlagType(StrEnum):
+ BOOLEAN = "boolean"
+ STRING = "string"
+ INTEGER = "integer"
+ FLOAT = "float"
+ JSON = "json"
+```
+
+---
+
+### `FlagVariation`
+
+One possible value a flag can return.
+
+```python
+class FlagVariation(BaseModel):
+ name: str # e.g. "on", "off", "control", "variant_a"
+ value: bool | str | int | float | dict | list
+ description: str = ""
+```
+
+---
+
+### `RolloutVariation`
+
+One bucket in a percentage rollout (used in `fallthrough` or `TargetingRule.rollout`).
+
+```python
+class RolloutVariation(BaseModel):
+ variation: str # references FlagVariation.name
+ weight: int # share of traffic, out of 100_000 total
+```
+
+Weights in a rollout list must sum to `100_000`. Examples:
+
+| Percentage | Weight |
+|---|---|
+| 10% | `10_000` |
+| 25% | `25_000` |
+| 33.33% | `33_333` |
+| 50% | `50_000` |
+
+---
+
+### `TargetingRule`
+
+A rule that matches clauses and serves a variation.
+
+```python
+class TargetingRule(BaseModel):
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ description: str = ""
+ clauses: list[RuleClause] = []
+ variation: str | None = None # mutually exclusive with rollout
+ rollout: list[RolloutVariation] | None = None
+ track_events: bool = False
+```
+
+Clauses within a rule are AND-ed. Rules are evaluated top-to-bottom; first match wins.
+
+---
+
+### `RuleClause`
+
+A single condition within a targeting rule.
+
+```python
+class RuleClause(BaseModel):
+ attribute: str # context attribute to inspect, e.g. "plan", "country", "email"
+ operator: Operator # comparison to apply
+ values: list[Any] # one or more values; multiple values use OR logic
+ negate: bool = False # invert the result
+```
+
+---
+
+### `Operator`
+
+All supported targeting operators.
+
+```python
+class Operator(StrEnum):
+ # Equality
+ IS = "is"
+ IS_NOT = "is_not"
+ # String
+ CONTAINS = "contains"
+ NOT_CONTAINS = "not_contains"
+ STARTS_WITH = "starts_with"
+ ENDS_WITH = "ends_with"
+ MATCHES = "matches" # Python regex
+ NOT_MATCHES = "not_matches"
+ # Numeric
+ GT = "gt"
+ GTE = "gte"
+ LT = "lt"
+ LTE = "lte"
+ # Date (ISO-8601 lexicographic)
+ BEFORE = "before"
+ AFTER = "after"
+ # Collection
+ IN = "in"
+ NOT_IN = "not_in"
+ # Segment
+ IN_SEGMENT = "in_segment"
+ NOT_IN_SEGMENT = "not_in_segment"
+ # Semantic version (requires `packaging`)
+ SEMVER_EQ = "semver_eq"
+ SEMVER_LT = "semver_lt"
+ SEMVER_GT = "semver_gt"
+```
+
+---
+
+### `Prerequisite`
+
+A flag that must evaluate to a specific variation before the dependent flag runs.
+
+```python
+class Prerequisite(BaseModel):
+ flag_key: str # key of the prerequisite flag
+ variation: str # variation the prerequisite must return
+```
+
+If the prerequisite returns any other variation, the dependent flag serves `off_variation` with reason `PREREQUISITE_FAIL`.
+
+---
+
+### `Segment`
+
+A reusable group of users for flag targeting.
+
+```python
+class Segment(BaseModel):
+ key: str
+ name: str
+ description: str = ""
+ included: list[str] = [] # context keys always in the segment
+ excluded: list[str] = [] # context keys always excluded (overrides rules + included)
+ rules: list[SegmentRule] = []
+ tags: list[str] = []
+ created_at: datetime
+ updated_at: datetime
+```
+
+**Evaluation order for context key `k`:**
+
+1. `k` in `excluded` → not in segment
+2. `k` in `included` → in segment
+3. Any `SegmentRule` matches → in segment
+4. Otherwise → not in segment
+
+---
+
+### `SegmentRule`
+
+An attribute-based rule inside a segment. Multiple segment rules are OR-ed: if any rule matches, the user is in the segment.
+
+```python
+class SegmentRule(BaseModel):
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ description: str = ""
+ clauses: list[RuleClause] = [] # all must match (AND logic)
+```
+
+---
+
+### `EvaluationContext`
+
+Per-request context for flag targeting and rollout bucketing.
+
+```python
+class EvaluationContext(BaseModel):
+ key: str # required — user/session/org ID
+ kind: str = "user" # context kind
+ email: str | None = None
+ ip: str | None = None
+ country: str | None = None
+ app_version: str | None = None
+ attributes: dict[str, Any] = {} # any additional attributes
+```
+
+Named fields (`email`, `ip`, `country`, `app_version`) are accessible in rule clauses by the same names. Items in `attributes` are merged in and accessible by key.
+
+---
+
+### `ResolutionDetails`
+
+Full result of a flag evaluation, surfaced in hooks.
+
+```python
+class ResolutionDetails(BaseModel):
+ value: Any
+ variation: str | None = None
+ reason: EvaluationReason
+ rule_id: str | None = None # set when reason == RULE_MATCH
+ prerequisite_key: str | None = None # set when reason == PREREQUISITE_FAIL
+ error_message: str | None = None # set when reason == ERROR
+```
+
+---
+
+### `EvaluationReason`
+
+Why a specific value was returned.
+
+| Value | Description |
+|---|---|
+| `OFF` | Flag is globally disabled. `off_variation` was served. |
+| `TARGET_MATCH` | Context key was in the individual targets list. |
+| `RULE_MATCH` | A targeting rule matched. `rule_id` is set. |
+| `FALLTHROUGH` | No targeting rule matched. Default rule was served. |
+| `PREREQUISITE_FAIL` | A prerequisite flag did not return the required variation. |
+| `ERROR` | Provider or evaluation error. Default value was returned. |
+| `DEFAULT` | Flag not found. SDK default was returned. |
+
+---
+
+---
+
+## REST API
+
+When `ShieldAdmin` is mounted with `engine.use_openfeature()`, these endpoints are registered under the admin path (e.g. `/shield/api/`):
+
+### Flags
+
+| Method | Path | Description |
+|---|---|---|
+| `GET` | `/api/flags` | List all flags |
+| `POST` | `/api/flags` | Create a flag (full `FeatureFlag` body) |
+| `GET` | `/api/flags/{key}` | Get a single flag |
+| `PUT` | `/api/flags/{key}` | Replace a flag (full update) |
+| `PATCH` | `/api/flags/{key}` | Partial update |
+| `DELETE` | `/api/flags/{key}` | Delete a flag |
+| `POST` | `/api/flags/{key}/enable` | Enable (kill-switch off) |
+| `POST` | `/api/flags/{key}/disable` | Disable (kill-switch on) |
+| `POST` | `/api/flags/{key}/evaluate` | Evaluate for a given context |
+
+### Segments
+
+| Method | Path | Description |
+|---|---|---|
+| `GET` | `/api/segments` | List all segments |
+| `POST` | `/api/segments` | Create a segment |
+| `GET` | `/api/segments/{key}` | Get a single segment |
+| `PUT` | `/api/segments/{key}` | Replace a segment |
+| `DELETE` | `/api/segments/{key}` | Delete a segment |
+
+---
+
+## Evaluation algorithm
+
+The evaluator (`FlagEvaluator`) is pure Python with no I/O, unit-testable in isolation.
+
+```python
+from shield.core.feature_flags.evaluator import FlagEvaluator
+
+evaluator = FlagEvaluator(segments={"beta": beta_segment})
+result = evaluator.evaluate(flag, ctx, all_flags)
+print(result.value, result.reason)
+```
+
+**Rollout bucketing** uses SHA-1 of `"{flag_key}:{ctx.kind}:{ctx.key}"` modulo `100_000`. The same context always lands in the same bucket; bucketing is stable across restarts and deploys.
+
+**Prerequisite recursion** is limited to depth 10. Circular dependencies are rejected at write time by `engine.save_flag()`.
+
+---
+
+## Dashboard routes
+
+| URL | Page |
+|---|---|
+| `/shield/flags` | Flag list with search and status filters |
+| `/shield/flags/{key}` | Flag detail (4 tabs: Overview, Targeting, Variations, Settings) |
+| `/shield/segments` | Segment list |
+
+---
+
+## Example
+
+Example: [`examples/fastapi/feature_flags.py`](https://github.com/Attakay78/api-shield/blob/main/examples/fastapi/feature_flags.py)
+
+```bash
+uv run uvicorn examples.fastapi.feature_flags:app --reload
+```
diff --git a/docs/tutorial/feature-flags.md b/docs/tutorial/feature-flags.md
new file mode 100644
index 0000000..d1e58a6
--- /dev/null
+++ b/docs/tutorial/feature-flags.md
@@ -0,0 +1,455 @@
+# Feature Flags
+
+Feature flags (also called feature toggles) let you change your application's behavior per user without redeploying. The system is built on the [OpenFeature](https://openfeature.dev/) standard and supports boolean, string, integer, float, and JSON flags, multi-condition targeting rules, user segments, percentage rollouts, and prerequisites.
+
+!!! note "Optional dependency"
+ Feature flags require the `flags` extra:
+ ```bash
+ uv add "api-shield[flags]"
+ # or: pip install "api-shield[flags]"
+ ```
+
+---
+
+## Overview
+
+A feature flag has:
+
+- **Variations**: the possible values it can return (`on`/`off`, `"dark"`/`"light"`, `10`/`50`, etc.)
+- **Targeting**: rules that decide which variation a specific user receives
+- **Fallthrough**: the default variation when no rule matches (a fixed value or a percentage rollout)
+- **Kill-switch**: `enabled=False` skips all rules and returns the `off_variation` immediately
+
+Evaluation always follows this order:
+
+```
+1. Flag disabled? → off_variation
+2. Prerequisite flags? → off_variation if any prereq fails
+3. Individual targets? → fixed variation for specific user keys
+4. Targeting rules? → first matching rule wins
+5. Fallthrough → fixed variation or percentage bucket
+```
+
+---
+
+## Installation and setup
+
+```bash
+uv add "api-shield[flags]"
+```
+
+Call `engine.use_openfeature()` once before your first evaluation, then access the flag client through `engine.flag_client`:
+
+```python
+from shield.core.config import make_engine
+
+engine = make_engine()
+engine.use_openfeature() # activates the feature flag subsystem
+```
+
+The flag client is a standard OpenFeature client — any OpenFeature-aware code works with it directly.
+
+---
+
+## Your first flag
+
+```python
+from shield.core.feature_flags.models import (
+ FeatureFlag, FlagType, FlagVariation, RolloutVariation,
+ EvaluationContext,
+)
+
+# 1. Define and save the flag
+await engine.save_flag(
+ FeatureFlag(
+ key="new-checkout",
+ name="New Checkout Flow",
+ type=FlagType.BOOLEAN,
+ variations=[
+ FlagVariation(name="on", value=True),
+ FlagVariation(name="off", value=False),
+ ],
+ off_variation="off",
+ fallthrough=[ # 20% of users get "on"
+ RolloutVariation(variation="on", weight=20_000),
+ RolloutVariation(variation="off", weight=80_000),
+ ],
+ )
+)
+
+# 2. Evaluate it in a route handler
+ctx = EvaluationContext(key=user_id)
+enabled = await engine.flag_client.get_boolean_value("new-checkout", False, ctx)
+```
+
+Rollout weights are integers out of `100_000`. The above gives exactly 20% to `"on"` and 80% to `"off"`. Bucketing is deterministic: the same `user_id` always lands in the same bucket.
+
+---
+
+## Flag types
+
+| Type | Method | Python type |
+|---|---|---|
+| `FlagType.BOOLEAN` | `get_boolean_value` | `bool` |
+| `FlagType.STRING` | `get_string_value` | `str` |
+| `FlagType.INTEGER` | `get_integer_value` | `int` |
+| `FlagType.FLOAT` | `get_float_value` | `float` |
+| `FlagType.JSON` | `get_object_value` | `dict` / `list` |
+
+All evaluation methods share the same signature: `(flag_key, default_value, context)`.
+
+```python
+# String flag
+theme = await engine.flag_client.get_string_value("ui-theme", "light", ctx)
+
+# Integer flag
+page_size = await engine.flag_client.get_integer_value("page-size", 10, ctx)
+
+# Float flag
+discount = await engine.flag_client.get_float_value("discount-rate", 0.0, ctx)
+
+# JSON flag — returns a dict
+config = await engine.flag_client.get_object_value("feature-config", {}, ctx)
+```
+
+---
+
+## Evaluation context
+
+`EvaluationContext` identifies who is making the request. The `key` field is required; use a stable user or session identifier. Everything else is optional:
+
+```python
+ctx = EvaluationContext(
+ key=user.id, # required — used for individual targeting + rollout bucketing
+ kind="user", # optional — defaults to "user"
+ email=user.email, # accessible in rules as the "email" attribute
+ ip=request.client.host,
+ country=user.country,
+ app_version="2.3.1",
+ attributes={ # any extra attributes your rules need
+ "plan": user.plan,
+ "role": user.role,
+ },
+)
+```
+
+Named fields (`email`, `ip`, `country`, `app_version`) are accessible in targeting rules by the same names. Custom attributes go in `attributes`.
+
+---
+
+## Targeting rules
+
+Targeting rules serve a specific variation to users who match certain conditions.
+
+### Attribute-based rule
+
+```python
+from shield.core.feature_flags.models import TargetingRule, RuleClause, Operator
+
+FeatureFlag(
+ key="ui-theme",
+ ...
+ rules=[
+ TargetingRule(
+ description="Corporate users → dark theme",
+ clauses=[
+ RuleClause(
+ attribute="email",
+ operator=Operator.ENDS_WITH,
+ values=["@acme.com"],
+ )
+ ],
+ variation="dark",
+ )
+ ],
+)
+```
+
+### Multiple clauses (AND logic)
+
+All clauses within a rule must match (AND). Multiple values within one clause are OR-ed.
+
+```python
+TargetingRule(
+ description="GB Pro users → full discount",
+ clauses=[
+ RuleClause(attribute="country", operator=Operator.IS, values=["GB"]),
+ RuleClause(attribute="plan", operator=Operator.IN, values=["pro", "enterprise"]),
+ ],
+ variation="full",
+)
+```
+
+### Negation
+
+Flip the result of any clause with `negate=True`:
+
+```python
+RuleClause(attribute="plan", operator=Operator.IS, values=["free"], negate=True)
+# matches any user NOT on the free plan
+```
+
+### Available operators
+
+| Category | Operators |
+|---|---|
+| Equality | `IS`, `IS_NOT` |
+| String | `CONTAINS`, `NOT_CONTAINS`, `STARTS_WITH`, `ENDS_WITH`, `MATCHES`, `NOT_MATCHES` |
+| Numeric | `GT`, `GTE`, `LT`, `LTE` |
+| Date | `BEFORE`, `AFTER` (ISO-8601 string comparison) |
+| Collection | `IN`, `NOT_IN` |
+| Segment | `IN_SEGMENT`, `NOT_IN_SEGMENT` |
+| Semver | `SEMVER_EQ`, `SEMVER_LT`, `SEMVER_GT` |
+
+---
+
+## Individual targeting
+
+Override rules for specific users by listing their context keys in `targets`. Individual targets are evaluated after prerequisites but before rules, and always win.
+
+```python
+FeatureFlag(
+ key="new-checkout",
+ ...
+ targets={
+ "on": ["beta_tester_1", "beta_tester_2"], # these users always get "on"
+ "off": ["opted_out_user"], # this user always gets "off"
+ },
+)
+```
+
+---
+
+## Segments
+
+A segment is a named, reusable group of users. Define it once and reference it in any flag's targeting rules with `Operator.IN_SEGMENT`.
+
+### Creating a segment
+
+```python
+from shield.core.feature_flags.models import Segment, SegmentRule, RuleClause, Operator
+
+# Explicit include list
+await engine.save_segment(Segment(
+ key="beta-users",
+ name="Beta Users",
+ included=["user_123", "user_456", "user_789"],
+))
+
+# Attribute-based rules (any matching rule → user is in the segment)
+await engine.save_segment(Segment(
+ key="enterprise-plan",
+ name="Enterprise Plan",
+ rules=[
+ SegmentRule(clauses=[
+ RuleClause(attribute="plan", operator=Operator.IS, values=["enterprise"]),
+ ]),
+ ],
+))
+
+# Exclude specific users even if they match a rule
+await engine.save_segment(Segment(
+ key="paid-users",
+ name="Paid Users",
+ rules=[
+ SegmentRule(clauses=[
+ RuleClause(attribute="plan", operator=Operator.IN, values=["pro", "enterprise"]),
+ ]),
+ ],
+ excluded=["test_account", "demo_user"], # always excluded, overrides rules
+))
+```
+
+### Segment evaluation order
+
+For a given context key `k`:
+
+1. `k` in `excluded` → **not** in segment
+2. `k` in `included` → in segment
+3. Any `rules` entry matches → in segment
+4. Otherwise → not in segment
+
+!!! important "Segment key ≠ user key"
+ The segment **key** (e.g. `"beta-users"`) is the segment's identifier. To make a user with `user_id="alice"` part of this segment, add `"alice"` to `included` — or add a segment rule that matches her attributes. Simply naming the segment `"alice"` does not put her in it.
+
+### Using a segment in a flag rule
+
+```python
+TargetingRule(
+ description="Beta users get the new flow",
+ clauses=[
+ RuleClause(
+ attribute="key", # evaluates ctx.key against the segment
+ operator=Operator.IN_SEGMENT,
+ values=["beta-users"], # segment key to reference
+ )
+ ],
+ variation="on",
+)
+```
+
+### Managing segments from the dashboard
+
+Open the **Segments** page (`/shield/segments`) and click a segment key or **Edit** to:
+
+- Add or remove users from the **Included** and **Excluded** lists
+- Add **targeting rules** — attribute-based conditions evaluated when a user isn't in the explicit lists
+
+### Managing segments from the CLI
+
+```bash
+# List all segments
+shield segments list
+
+# Inspect a segment
+shield segments get beta-users
+
+# Create a segment
+shield segments create beta_users --name "Beta Users"
+
+# Add users to the included list
+shield segments include beta_users --context-key user_123,user_456
+
+# Remove users via the excluded list
+shield segments exclude beta_users --context-key opted_out_user
+
+# Add an attribute-based targeting rule
+shield segments add-rule beta_users --attribute plan --operator in --values pro,enterprise
+shield segments add-rule beta_users --attribute country --operator is --values GB --description "UK users"
+
+# Remove a rule (use 'shield segments get' to find rule IDs)
+shield segments remove-rule beta_users --rule-id
+
+# Delete a segment
+shield segments delete beta_users
+```
+
+---
+
+## Prerequisites
+
+Prerequisites let a flag depend on another flag. The dependent flag only proceeds to its rules if the prerequisite flag evaluates to a specific variation.
+
+```python
+from shield.core.feature_flags.models import Prerequisite
+
+FeatureFlag(
+ key="advanced-dashboard",
+ ...
+ prerequisites=[
+ Prerequisite(flag_key="auth-v2", variation="enabled"),
+ # advanced-dashboard only evaluates if auth-v2 → "enabled"
+ ],
+)
+```
+
+Prerequisites are recursive up to a depth of 10. Circular dependencies are prevented at write time.
+
+---
+
+## Sync evaluation (plain `def` handlers)
+
+FastAPI runs plain `def` route handlers in a thread pool. Use `engine.sync.flag_client` for thread-safe synchronous evaluation without any event loop bridging:
+
+```python
+@router.get("/dashboard")
+def dashboard(request: Request, user_id: str = "anonymous"):
+ enabled = engine.sync.flag_client.get_boolean_value(
+ "new-dashboard", False, {"targeting_key": user_id}
+ )
+ return {"new_dashboard": enabled}
+```
+
+---
+
+## Admin dashboard
+
+### Flags page (`/shield/flags`)
+
+Lists all flags with key, type, status, variations, and fallthrough. Use the search box and type/status filters to narrow the list. Click a flag key to open the detail page.
+
+### Flag detail page
+
+| Tab | Contents |
+|---|---|
+| **Overview** | Key metrics: evaluation count, rule match rate, fallthrough rate, top variations |
+| **Targeting** | Add / remove prerequisite flags; manage individual targets; add / edit / delete targeting rules |
+| **Variations** | Add, rename, or remove variations; change the fallthrough and off-variation |
+| **Settings** | Edit name, description, tags, maintainer, temporary flag flag, and scheduled changes |
+
+### Segments page (`/shield/segments`)
+
+Lists all segments with included/excluded/rules counts. Click a segment to open its detail modal, or use the **Edit** button to manage included, excluded, and targeting rules.
+
+---
+
+## CLI reference
+
+### `shield flags`
+
+```bash
+shield flags list # all flags
+shield flags get new-checkout # flag detail
+shield flags create new-checkout boolean # create (interactive prompts follow)
+shield flags enable new-checkout # enable (kill-switch off)
+shield flags disable new-checkout # disable (kill-switch on)
+shield flags delete new-checkout # permanently delete
+
+shield flags eval new-checkout --user user_123 # evaluate for a user
+
+shield flags targeting new-checkout # show targeting rules
+shield flags add-rule new-checkout \
+ --variation on \
+ --segment beta-users # add segment-based rule
+shield flags add-rule new-checkout \
+ --variation on \
+ --attribute plan --operator in --values pro,enterprise
+shield flags remove-rule new-checkout --rule-id
+
+shield flags add-prereq new-checkout --flag auth-v2 --variation enabled
+shield flags remove-prereq new-checkout --flag auth-v2
+
+shield flags target new-checkout --variation on --context-key user_123
+shield flags untarget new-checkout --context-key user_123
+
+shield flags variations new-checkout # list variations
+shield flags edit new-checkout # open interactive editor
+```
+
+### `shield segments`
+
+```bash
+shield segments list
+shield segments get beta-users
+shield segments create beta_users --name "Beta Users"
+shield segments include beta_users --context-key user_123,user_456
+shield segments exclude beta_users --context-key opted_out
+shield segments add-rule beta_users --attribute plan --operator in --values pro,enterprise
+shield segments remove-rule beta_users --rule-id
+shield segments delete beta_users
+```
+
+---
+
+## Full example
+
+Full example at [`examples/fastapi/feature_flags.py`](https://github.com/Attakay78/api-shield/blob/main/examples/fastapi/feature_flags.py), covering all five flag types, individual targeting, attribute-based rules, percentage rollouts, and sync and async evaluation.
+
+Run it with:
+
+```bash
+uv run uvicorn examples.fastapi.feature_flags:app --reload
+```
+
+Then visit:
+
+- `http://localhost:8000/docs` — Swagger UI
+- `http://localhost:8000/shield/flags` — flag management dashboard
+- `http://localhost:8000/checkout?user_id=beta_tester_1` — targeted user (always `"on"`)
+- `http://localhost:8000/checkout?user_id=anyone_else` — 20% rollout
+
+---
+
+## Next step
+
+[**Reference: Feature Flags →**](../reference/feature-flags.md)
diff --git a/docs/tutorial/installation.md b/docs/tutorial/installation.md
index 962b982..828c7bd 100644
--- a/docs/tutorial/installation.md
+++ b/docs/tutorial/installation.md
@@ -22,6 +22,9 @@ uv add "api-shield[fastapi,cli]"
# FastAPI + rate limiting
uv add "api-shield[fastapi,rate-limit]"
+# FastAPI + feature flags
+uv add "api-shield[fastapi,flags]"
+
# Everything (FastAPI adapter, Redis, dashboard, CLI, admin, rate limiting)
uv add "api-shield[all]"
```
@@ -44,6 +47,7 @@ pip install "api-shield[all]"
| `admin` | Unified `ShieldAdmin` (dashboard + REST API) | Recommended for CLI support |
| `cli` | `shield` command-line tool + httpx client | Operators managing routes from the terminal |
| `rate-limit` | `limits` library for `@rate_limit` enforcement | Any app using rate limiting |
+| `flags` | `openfeature-sdk` + `packaging` for the feature flag system | Any app using feature flags |
| `all` | All of the above | Easiest option for most projects |
---
@@ -83,6 +87,7 @@ SHIELD_SERVER_URL=http://localhost:8000/shield
---
-## Next step
+## Next steps
-[**Tutorial: Your first decorator →**](first-decorator.md)
+- [**Tutorial: Your first decorator →**](first-decorator.md)
+- [**Tutorial: Feature Flags →**](feature-flags.md)
diff --git a/examples/fastapi/feature_flags.py b/examples/fastapi/feature_flags.py
new file mode 100644
index 0000000..207f80d
--- /dev/null
+++ b/examples/fastapi/feature_flags.py
@@ -0,0 +1,434 @@
+"""FastAPI — Feature Flags Example.
+
+Demonstrates the full feature-flag API powered by OpenFeature:
+
+ * Boolean / string / integer / float / JSON flag types
+ * Async evaluation (``await engine.flag_client.get_boolean_value(...)``)
+ * Sync evaluation (``engine.sync.flag_client.get_boolean_value(...)``)
+ * EvaluationContext — per-request targeting based on user attributes
+ * Individual targeting — specific users always get a fixed variation
+ * Targeting rules — serve variations based on plan, country, app_version
+ * Percentage rollout (fallthrough) — gradual feature release
+ * Kill-switch — disable a flag globally without redeploying
+ * Live event stream — watch evaluations in real time
+
+Prerequisites:
+ pip install api-shield[flags]
+ # or:
+ uv pip install "api-shield[flags]"
+
+Run:
+ uv run uvicorn examples.fastapi.feature_flags:app --reload
+
+Then visit:
+ http://localhost:8000/docs — Swagger UI
+ http://localhost:8000/shield/ — admin dashboard (login: admin / secret)
+ http://localhost:8000/shield/flags — flag management UI
+
+Exercise the endpoints:
+ # Boolean flag — new checkout flow (async route)
+ curl "http://localhost:8000/checkout?user_id=user_123"
+
+ # Boolean flag — new checkout flow (sync/def route)
+ curl "http://localhost:8000/checkout/sync?user_id=user_123"
+
+ # String flag — UI theme selection
+ curl "http://localhost:8000/theme?user_id=beta_user_1"
+
+ # Integer flag — max results per page
+ curl "http://localhost:8000/search?user_id=pro_user_1&plan=pro"
+
+ # Float flag — discount rate for a country segment
+ curl "http://localhost:8000/pricing?user_id=uk_user_1&country=GB"
+
+ # JSON flag — feature configuration bundle
+ curl "http://localhost:8000/config?user_id=user_123"
+
+ # Targeting: individual user always gets the beta variation
+ curl "http://localhost:8000/checkout?user_id=beta_tester_1"
+
+ # Live event stream (SSE) — watch evaluations happen in real time
+ curl -N "http://localhost:8000/shield/api/flags/stream"
+
+CLI — manage flags without redeploying:
+ shield login admin # password: secret
+ shield flags list
+ shield flags get new-checkout
+ shield flags disable new-checkout # kill-switch
+ shield flags enable new-checkout # restore
+ shield flags stream # tail live evaluations
+ shield flags stream new-checkout # filter to one flag
+"""
+
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+from typing import Any
+
+from fastapi import FastAPI, Request
+
+from shield.admin import ShieldAdmin
+from shield.core.config import make_engine
+from shield.core.feature_flags.models import (
+ EvaluationContext,
+ FeatureFlag,
+ FlagType,
+ FlagVariation,
+ Operator,
+ RolloutVariation,
+ RuleClause,
+ TargetingRule,
+)
+from shield.fastapi import (
+ ShieldMiddleware,
+ ShieldRouter,
+ apply_shield_to_openapi,
+)
+
+# ---------------------------------------------------------------------------
+# Engine setup
+# ---------------------------------------------------------------------------
+
+engine = make_engine()
+engine.use_openfeature()
+
+router = ShieldRouter(engine=engine)
+
+
+# ---------------------------------------------------------------------------
+# Seed flags at startup
+# ---------------------------------------------------------------------------
+
+
+async def _seed_flags() -> None:
+ """Register all feature flags.
+
+ In production you would persist flags to a shared backend (Redis, file)
+ or manage them via the dashboard / REST API. This function is for
+ demonstration only — flags created here exist only in memory.
+ """
+
+ # ------------------------------------------------------------------
+ # 1. Boolean flag — new checkout flow
+ #
+ # Individual targeting: beta_tester_1 always sees the new flow.
+ # Fallthrough: 20% of remaining users get "on", 80% get "off".
+ # ------------------------------------------------------------------
+ await engine.save_flag(
+ FeatureFlag(
+ key="new-checkout",
+ name="New Checkout Flow",
+ description="Gradual rollout of the redesigned checkout experience.",
+ type=FlagType.BOOLEAN,
+ variations=[
+ FlagVariation(name="on", value=True, description="New flow enabled"),
+ FlagVariation(name="off", value=False, description="Legacy flow"),
+ ],
+ off_variation="off",
+ # 20 % rollout — weights out of 100_000
+ fallthrough=[
+ RolloutVariation(variation="on", weight=20_000),
+ RolloutVariation(variation="off", weight=80_000),
+ ],
+ targets={"on": ["beta_tester_1", "beta_tester_2"]},
+ )
+ )
+
+ # ------------------------------------------------------------------
+ # 2. String flag — UI theme
+ #
+ # Rule: users whose email ends with "@acme.com" always get "dark".
+ # Fallthrough: everyone else gets "light".
+ # ------------------------------------------------------------------
+ await engine.save_flag(
+ FeatureFlag(
+ key="ui-theme",
+ name="UI Theme",
+ description="Default UI theme served to users.",
+ type=FlagType.STRING,
+ variations=[
+ FlagVariation(name="light", value="light"),
+ FlagVariation(name="dark", value="dark"),
+ FlagVariation(name="system", value="system"),
+ ],
+ off_variation="light",
+ fallthrough="light",
+ rules=[
+ TargetingRule(
+ description="Corporate users → dark theme",
+ clauses=[
+ RuleClause(
+ attribute="email",
+ operator=Operator.ENDS_WITH,
+ values=["@acme.com"],
+ )
+ ],
+ variation="dark",
+ )
+ ],
+ )
+ )
+
+ # ------------------------------------------------------------------
+ # 3. Integer flag — search results per page
+ #
+ # Rule: "pro" and "enterprise" plans get 50 results.
+ # Fallthrough: free-tier users get 10.
+ # ------------------------------------------------------------------
+ await engine.save_flag(
+ FeatureFlag(
+ key="search-page-size",
+ name="Search Page Size",
+ description="Max results returned per search request.",
+ type=FlagType.INTEGER,
+ variations=[
+ FlagVariation(name="small", value=10, description="Free tier"),
+ FlagVariation(name="large", value=50, description="Pro / enterprise"),
+ ],
+ off_variation="small",
+ fallthrough="small",
+ rules=[
+ TargetingRule(
+ description="Paid plans → large page size",
+ clauses=[
+ RuleClause(
+ attribute="plan",
+ operator=Operator.IN,
+ values=["pro", "enterprise"],
+ )
+ ],
+ variation="large",
+ )
+ ],
+ )
+ )
+
+ # ------------------------------------------------------------------
+ # 4. Float flag — regional discount rate
+ #
+ # Rule: GB users get a 15 % discount.
+ # Rule: EU users get a 10 % discount.
+ # Fallthrough: no discount (0.0).
+ # ------------------------------------------------------------------
+ await engine.save_flag(
+ FeatureFlag(
+ key="discount-rate",
+ name="Regional Discount Rate",
+ description="Fractional discount applied at checkout (0.0 = none, 0.15 = 15%).",
+ type=FlagType.FLOAT,
+ variations=[
+ FlagVariation(name="none", value=0.0),
+ FlagVariation(name="eu", value=0.10),
+ FlagVariation(name="gb", value=0.15),
+ ],
+ off_variation="none",
+ fallthrough="none",
+ rules=[
+ TargetingRule(
+ description="GB → 15% discount",
+ clauses=[RuleClause(attribute="country", operator=Operator.IS, values=["GB"])],
+ variation="gb",
+ ),
+ TargetingRule(
+ description="EU → 10% discount",
+ clauses=[
+ RuleClause(
+ attribute="country",
+ operator=Operator.IN,
+ values=["DE", "FR", "NL", "SE", "PL"],
+ )
+ ],
+ variation="eu",
+ ),
+ ],
+ )
+ )
+
+ # ------------------------------------------------------------------
+ # 5. JSON flag — feature configuration bundle
+ #
+ # Returns a structured dict with multiple settings in one round-trip.
+ # Useful for feature bundles that require several related values.
+ # ------------------------------------------------------------------
+ await engine.save_flag(
+ FeatureFlag(
+ key="feature-config",
+ name="Feature Configuration Bundle",
+ description="Combined config object for the new dashboard experience.",
+ type=FlagType.JSON,
+ variations=[
+ FlagVariation(
+ name="v2",
+ value={
+ "sidebar": True,
+ "analytics": True,
+ "export_formats": ["csv", "xlsx", "json"],
+ "max_widgets": 20,
+ },
+ description="Full v2 dashboard",
+ ),
+ FlagVariation(
+ name="v1",
+ value={
+ "sidebar": False,
+ "analytics": False,
+ "export_formats": ["csv"],
+ "max_widgets": 5,
+ },
+ description="Legacy v1 dashboard",
+ ),
+ ],
+ off_variation="v1",
+ fallthrough="v1",
+ )
+ )
+
+
+# ---------------------------------------------------------------------------
+# Routes — async (def async)
+# ---------------------------------------------------------------------------
+
+
+@router.get("/checkout")
+async def checkout(request: Request, user_id: str = "anonymous"):
+ """Async route: evaluate the boolean ``new-checkout`` flag.
+
+ Pass ``?user_id=beta_tester_1`` to see individual targeting in action.
+ The flag is on a 20 % rollout for everyone else.
+ """
+ ctx = EvaluationContext(key=user_id)
+ enabled = await engine.flag_client.get_boolean_value("new-checkout", False, ctx)
+ return {
+ "user_id": user_id,
+ "new_checkout": enabled,
+ "flow": "v2" if enabled else "v1",
+ }
+
+
+@router.get("/theme")
+async def theme(request: Request, user_id: str = "anonymous", email: str = ""):
+ """Async route: evaluate the string ``ui-theme`` flag.
+
+ Pass ``?email=you@acme.com`` to trigger the corporate-user rule.
+ """
+ ctx = EvaluationContext(key=user_id, email=email or None)
+ selected_theme = await engine.flag_client.get_string_value("ui-theme", "light", ctx)
+ return {"user_id": user_id, "theme": selected_theme}
+
+
+@router.get("/search")
+async def search(request: Request, user_id: str = "anonymous", plan: str = "free"):
+ """Async route: evaluate the integer ``search-page-size`` flag.
+
+ Pass ``?plan=pro`` or ``?plan=enterprise`` to get the larger page size.
+ """
+ ctx = EvaluationContext(key=user_id, attributes={"plan": plan})
+ page_size = await engine.flag_client.get_integer_value("search-page-size", 10, ctx)
+ return {"user_id": user_id, "plan": plan, "page_size": page_size, "results": []}
+
+
+@router.get("/pricing")
+async def pricing(request: Request, user_id: str = "anonymous", country: str = "US"):
+ """Async route: evaluate the float ``discount-rate`` flag.
+
+ Pass ``?country=GB`` (15 %) or ``?country=DE`` (10 %).
+ """
+ ctx = EvaluationContext(key=user_id, country=country)
+ discount = await engine.flag_client.get_float_value("discount-rate", 0.0, ctx)
+ return {
+ "user_id": user_id,
+ "country": country,
+ "discount_rate": discount,
+ "price_usd": round(100.0 * (1 - discount), 2),
+ }
+
+
+@router.get("/config")
+async def config(request: Request, user_id: str = "anonymous"):
+ """Async route: evaluate the JSON ``feature-config`` flag.
+
+ Returns the entire configuration bundle in a single evaluation call.
+ """
+ ctx = EvaluationContext(key=user_id)
+ cfg: Any = await engine.flag_client.get_object_value(
+ "feature-config", {"sidebar": False, "analytics": False}, ctx
+ )
+ return {"user_id": user_id, "config": cfg}
+
+
+# ---------------------------------------------------------------------------
+# Routes — sync (def, no async)
+# ---------------------------------------------------------------------------
+# FastAPI runs plain ``def`` handlers in a thread pool.
+# ``engine.sync.flag_client`` provides a thread-safe synchronous facade over
+# the same OpenFeature client — no asyncio bridge needed because flag
+# evaluation is pure Python with no I/O.
+# ---------------------------------------------------------------------------
+
+
+@router.get("/checkout/sync")
+def checkout_sync(request: Request, user_id: str = "anonymous"):
+ """Sync route: evaluate the ``new-checkout`` flag from a ``def`` handler.
+
+ Identical result to ``GET /checkout`` — use whichever matches your handler style.
+ """
+ enabled = engine.sync.flag_client.get_boolean_value(
+ "new-checkout", False, {"targeting_key": user_id}
+ )
+ return {
+ "user_id": user_id,
+ "new_checkout": enabled,
+ "flow": "v2" if enabled else "v1",
+ "evaluated_in": "sync",
+ }
+
+
+@router.get("/search/sync")
+def search_sync(request: Request, user_id: str = "anonymous", plan: str = "free"):
+ """Sync route: evaluate the ``search-page-size`` flag from a ``def`` handler."""
+ page_size = engine.sync.flag_client.get_integer_value(
+ "search-page-size", 10, {"targeting_key": user_id, "plan": plan}
+ )
+ return {
+ "user_id": user_id,
+ "plan": plan,
+ "page_size": page_size,
+ "evaluated_in": "sync",
+ }
+
+
+# ---------------------------------------------------------------------------
+# App assembly
+# ---------------------------------------------------------------------------
+
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+ await _seed_flags()
+ yield
+
+
+app = FastAPI(
+ title="api-shield — Feature Flags Example",
+ description=(
+ "Demonstrates boolean, string, integer, float, and JSON flags with "
+ "targeting rules, rollouts, kill-switches, and live event streaming.\n\n"
+ "Requires `api-shield[flags]` (`pip install api-shield[flags]`)."
+ ),
+ lifespan=lifespan,
+)
+
+app.add_middleware(ShieldMiddleware, engine=engine)
+app.include_router(router)
+apply_shield_to_openapi(app, engine)
+
+app.mount(
+ "/shield",
+ ShieldAdmin(
+ engine=engine,
+ auth=("admin", "secret"),
+ prefix="/shield",
+ # enable_flags is auto-detected from engine.use_openfeature() — no
+ # need to set it explicitly. Set to True/False to override.
+ ),
+)
diff --git a/examples/fastapi/multi_service.py b/examples/fastapi/multi_service.py
index bf015a9..a999563 100644
--- a/examples/fastapi/multi_service.py
+++ b/examples/fastapi/multi_service.py
@@ -8,13 +8,13 @@
This file defines THREE separate ASGI apps. Run each in its own terminal:
Shield Server (port 8001):
- uv run uvicorn examples.fastapi.multi_service:shield_app --port 8001 --reload
+ uv run --with uvicorn uvicorn examples.fastapi.multi_service:shield_app --port 8001 --reload
Payments service (port 8000):
- uv run uvicorn examples.fastapi.multi_service:payments_app --port 8000 --reload
+ uv run --with uvicorn uvicorn examples.fastapi.multi_service:payments_app --port 8000 --reload
Orders service (port 8002):
- uv run uvicorn examples.fastapi.multi_service:orders_app --port 8002 --reload
+ uv run --with uvicorn uvicorn examples.fastapi.multi_service:orders_app --port 8002 --reload
Then visit:
http://localhost:8001/ — Shield dashboard (admin / secret)
@@ -98,6 +98,7 @@
disabled,
force_active,
maintenance,
+ setup_shield_docs,
)
from shield.sdk import ShieldSDK
from shield.server import ShieldServer
@@ -134,6 +135,8 @@
payments_sdk = ShieldSDK(
server_url="http://localhost:8001",
app_id="payments-service",
+ username="admin",
+ password="secret",
# Auto-login (recommended): SDK obtains a 1-year sdk-platform token on startup.
# username="admin", # inject from env: os.environ["SHIELD_USERNAME"]
# password="secret", # inject from env: os.environ["SHIELD_PASSWORD"]
@@ -201,6 +204,7 @@ async def v2_invoices():
payments_app.include_router(payments_router)
apply_shield_to_openapi(payments_app, payments_sdk.engine)
+setup_shield_docs(payments_app, payments_sdk.engine)
# ---------------------------------------------------------------------------
# Orders Service (port 8002)
@@ -213,6 +217,8 @@ async def v2_invoices():
orders_sdk = ShieldSDK(
server_url="http://localhost:8001",
app_id="orders-service",
+ username="admin",
+ password="secret",
# Auto-login (recommended): SDK obtains a 1-year sdk-platform token on startup.
# username="admin", # inject from env: os.environ["SHIELD_USERNAME"]
# password="secret", # inject from env: os.environ["SHIELD_PASSWORD"]
@@ -278,6 +284,7 @@ async def get_cart():
orders_app.include_router(orders_router)
apply_shield_to_openapi(orders_app, orders_sdk.engine)
+setup_shield_docs(orders_app, orders_sdk.engine)
# ---------------------------------------------------------------------------
# CLI reference — multi-service workflow
diff --git a/examples/fastapi/shield_server.py b/examples/fastapi/shield_server.py
index 6d9798d..bc0e556 100644
--- a/examples/fastapi/shield_server.py
+++ b/examples/fastapi/shield_server.py
@@ -133,6 +133,8 @@
sdk = ShieldSDK(
server_url="http://localhost:8001",
app_id="my-service",
+ username="admin",
+ password="secret",
# username="admin", # or inject from env: os.environ["SHIELD_USERNAME"]
# password="secret", # or inject from env: os.environ["SHIELD_PASSWORD"]
reconnect_delay=5.0, # seconds between SSE reconnect attempts
diff --git a/mkdocs.yml b/mkdocs.yml
index 77f927a..ba7bd43 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -98,11 +98,13 @@ nav:
- Adding Middleware: tutorial/middleware.md
- Backends: tutorial/backends.md
- Rate Limiting: tutorial/rate-limiting.md
+ - Feature Flags: tutorial/feature-flags.md
- Admin Dashboard: tutorial/admin-dashboard.md
- CLI: tutorial/cli.md
- Reference:
- Decorators: reference/decorators.md
- Rate Limiting: reference/rate-limiting.md
+ - Feature Flags: reference/feature-flags.md
- ShieldEngine: reference/engine.md
- Backends: reference/backends.md
- Middleware: reference/middleware.md
diff --git a/pyproject.toml b/pyproject.toml
index 83e6a74..af82cc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
dependencies = [
"pydantic>=2.0",
"anyio>=4.0",
- "starlette>=0.27",
+ "starlette>=0.40",
]
[project.urls]
@@ -33,7 +33,7 @@ Repository = "https://github.com/Attakay78/api-shield"
Issues = "https://github.com/Attakay78/api-shield/issues"
[project.optional-dependencies]
-fastapi = ["fastapi>=0.100"]
+fastapi = ["fastapi>=0.115"]
redis = ["redis[asyncio]>=5.0"]
dashboard = [
"jinja2>=3.1",
@@ -49,8 +49,12 @@ admin = [
yaml = ["pyyaml>=6.0"]
toml = ["tomli-w>=1.0"]
rate-limit = ["limits>=5.8.0"]
+flags = [
+ "openfeature-sdk>=0.8",
+ "packaging>=23.0",
+]
all = [
- "fastapi>=0.100",
+ "fastapi>=0.115",
"redis[asyncio]>=5.0",
"jinja2>=3.1",
"aiofiles>=23.0",
@@ -61,6 +65,8 @@ all = [
"tomli-w>=1.0",
"python-multipart>=0.0.22",
"limits>=5.8.0",
+ "openfeature-sdk>=0.8",
+ "packaging>=23.0",
]
docs = [
"mkdocs-material>=9.5",
@@ -75,11 +81,13 @@ dev = [
"ruff",
"mypy",
"aiofiles>=23.0",
- "fastapi>=0.100",
+ "fastapi>=0.115",
"pre-commit>=3.7",
"tomli-w>=1.0",
"pyyaml>=6.0",
"limits>=5.8.0",
+ "openfeature-sdk>=0.8",
+ "packaging>=23.0",
"mkdocs-material>=9.5",
"mkdocstrings[python]>=0.25",
"mkdocs-git-revision-date-localized-plugin>=1.2",
@@ -106,6 +114,11 @@ target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "I", "UP"]
+[tool.ruff.lint.per-file-ignores]
+"shield/core/feature_flags/__init__.py" = ["E402"]
+"shield/core/feature_flags/hooks.py" = ["E402"]
+"shield/core/feature_flags/provider.py" = ["E402"]
+
[tool.mypy]
python_version = "3.11"
strict = true
diff --git a/shield/admin/api.py b/shield/admin/api.py
index 6fa1a65..3736015 100644
--- a/shield/admin/api.py
+++ b/shield/admin/api.py
@@ -588,11 +588,29 @@ async def _feed_rl_policies() -> None:
except Exception:
logger.exception("shield: SDK SSE RL policy subscription error")
+ async def _feed_flags() -> None:
+ try:
+ async for event in engine.backend.subscribe_flag_changes(): # type: ignore[attr-defined]
+ envelope = _json.dumps(event)
+ await queue.put(f"data: {envelope}\n\n")
+ except NotImplementedError:
+ pass
+ except asyncio.CancelledError:
+ raise
+ except Exception:
+ logger.exception("shield: SDK SSE flag subscription error")
+
async def _generate() -> object:
tasks.append(asyncio.create_task(_feed_states()))
tasks.append(asyncio.create_task(_feed_rl_policies()))
+ tasks.append(asyncio.create_task(_feed_flags()))
try:
while True:
+ # Check for client disconnect before blocking on the queue.
+ # is_disconnected() polls receive() with a 1 ms timeout so it
+ # never blocks the loop for more than a millisecond.
+ if await request.is_disconnected():
+ break
try:
# Block until an event arrives or 15 s elapses.
msg = await asyncio.wait_for(queue.get(), timeout=15.0)
@@ -605,6 +623,10 @@ async def _generate() -> object:
finally:
for t in tasks:
t.cancel()
+ # Await the feeder tasks so their finally blocks (which deregister
+ # subscriber queues) run before this handler returns. Errors are
+ # suppressed — we only care that cleanup completes.
+ await asyncio.gather(*tasks, return_exceptions=True)
return StreamingResponse(
_generate(), # type: ignore[arg-type]
@@ -704,3 +726,338 @@ async def sdk_audit(request: Request) -> JSONResponse:
await engine.backend.write_audit(entry)
return JSONResponse({"ok": True})
+
+
+# ---------------------------------------------------------------------------
+# Feature flag endpoints
+# ---------------------------------------------------------------------------
+#
+# These endpoints are only mounted when ShieldAdmin(enable_flags=True).
+# They require the [flags] optional extra to be installed — callers get a
+# clear 501 error if the extra is missing.
+# ---------------------------------------------------------------------------
+
+
+def _flags_not_configured() -> JSONResponse:
+ return JSONResponse(
+ {
+ "error": (
+ "Feature flags are not enabled. "
+ "Call engine.use_openfeature() and set enable_flags=True on ShieldAdmin."
+ )
+ },
+ status_code=501,
+ )
+
+
+def _flags_not_installed() -> JSONResponse:
+ return JSONResponse(
+ {
+ "error": (
+ "Feature flags require the [flags] extra. "
+ "Install with: pip install api-shield[flags]"
+ )
+ },
+ status_code=501,
+ )
+
+
+def _flag_models_available() -> bool:
+ """Return True if the openfeature extra is installed."""
+ try:
+ import openfeature # noqa: F401
+
+ return True
+ except ImportError:
+ return False
+
+
+async def list_flags(request: Request) -> JSONResponse:
+ """GET /api/flags — list all feature flags."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ flags = await _engine(request).list_flags()
+ return JSONResponse([f.model_dump(mode="json") for f in flags])
+
+
+async def get_flag(request: Request) -> JSONResponse:
+ """GET /api/flags/{key} — get a single feature flag."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ flag = await _engine(request).get_flag(key)
+ if flag is None:
+ return _err(f"Flag '{key}' not found", 404)
+ return JSONResponse(flag.model_dump(mode="json"))
+
+
+async def create_flag(request: Request) -> JSONResponse:
+ """POST /api/flags — create a new feature flag."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ try:
+ body = await request.json()
+ except Exception:
+ return _err("Invalid JSON body")
+
+ try:
+ from shield.core.feature_flags.models import FeatureFlag
+
+ flag = FeatureFlag.model_validate(body)
+ except Exception as exc:
+ return _err(f"Invalid flag definition: {exc}")
+
+ # Conflict check
+ existing = await _engine(request).get_flag(flag.key)
+ if existing is not None:
+ return _err(f"Flag '{flag.key}' already exists. Use PUT to update.", 409)
+
+ await _engine(request).save_flag(flag)
+ return JSONResponse(flag.model_dump(mode="json"), status_code=201)
+
+
+async def update_flag(request: Request) -> JSONResponse:
+ """PUT /api/flags/{key} — replace a feature flag (full update)."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ try:
+ body = await request.json()
+ except Exception:
+ return _err("Invalid JSON body")
+
+ # Key in URL must match key in body if provided.
+ if isinstance(body, dict) and body.get("key", key) != key:
+ return _err("Flag key in URL and body must match")
+
+ if isinstance(body, dict):
+ body["key"] = key
+
+ try:
+ from shield.core.feature_flags.models import FeatureFlag
+
+ flag = FeatureFlag.model_validate(body)
+ except Exception as exc:
+ return _err(f"Invalid flag definition: {exc}")
+
+ await _engine(request).save_flag(flag)
+ return JSONResponse(flag.model_dump(mode="json"))
+
+
+async def patch_flag(request: Request) -> JSONResponse:
+ """PATCH /api/flags/{key} — partial update of a feature flag."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ flag = await _engine(request).get_flag(key)
+ if flag is None:
+ return _err(f"Flag '{key}' not found", 404)
+ try:
+ body = await request.json()
+ except Exception:
+ return _err("Invalid JSON body")
+ if not isinstance(body, dict):
+ return _err("Body must be a JSON object")
+
+ # Never allow patching immutable fields
+ for immutable in ("key", "type"):
+ body.pop(immutable, None)
+
+ try:
+ from shield.core.feature_flags.models import FeatureFlag
+
+ # Build updated flag by merging patch onto existing
+ current = flag.model_dump(mode="python")
+ current.update(body)
+ updated = FeatureFlag.model_validate(current)
+ except Exception as exc:
+ return _err(f"Invalid patch: {exc}")
+
+ # Cross-field validation: off_variation and string fallthrough must name
+ # an existing variation (the model doesn't enforce this itself).
+ variation_names = {v.name for v in updated.variations}
+ if updated.off_variation not in variation_names:
+ return _err(f"off_variation '{updated.off_variation}' does not match any variation name")
+ if isinstance(updated.fallthrough, str) and updated.fallthrough not in variation_names:
+ return _err(f"fallthrough '{updated.fallthrough}' does not match any variation name")
+
+ await _engine(request).save_flag(updated)
+ return JSONResponse(updated.model_dump(mode="json"))
+
+
+async def enable_flag(request: Request) -> JSONResponse:
+ """POST /api/flags/{key}/enable — enable a feature flag."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ flag = await _engine(request).get_flag(key)
+ if flag is None:
+ return _err(f"Flag '{key}' not found", 404)
+ flag = flag.model_copy(update={"enabled": True})
+ await _engine(request).save_flag(flag)
+ return JSONResponse(flag.model_dump(mode="json"))
+
+
+async def disable_flag(request: Request) -> JSONResponse:
+ """POST /api/flags/{key}/disable — disable a feature flag."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ flag = await _engine(request).get_flag(key)
+ if flag is None:
+ return _err(f"Flag '{key}' not found", 404)
+ flag = flag.model_copy(update={"enabled": False})
+ await _engine(request).save_flag(flag)
+ return JSONResponse(flag.model_dump(mode="json"))
+
+
+async def delete_flag(request: Request) -> JSONResponse:
+ """DELETE /api/flags/{key} — delete a feature flag."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ existing = await _engine(request).get_flag(key)
+ if existing is None:
+ return _err(f"Flag '{key}' not found", 404)
+ await _engine(request).delete_flag(key)
+ return JSONResponse({"ok": True, "deleted": key})
+
+
+async def evaluate_flag(request: Request) -> JSONResponse:
+ """POST /api/flags/{key}/evaluate — evaluate a flag for a given context.
+
+ Body: ``{"default": , "context": {"key": "user_1", "attributes": {...}}}``
+
+ Returns the resolved value, variation, reason, and any metadata.
+ Useful for debugging targeting rules from the dashboard or CLI.
+ """
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+
+ flag = await _engine(request).get_flag(key)
+ if flag is None:
+ return _err(f"Flag '{key}' not found", 404)
+
+ try:
+ body = await request.json()
+ except Exception:
+ body = {}
+
+ ctx_data = body.get("context", {}) if isinstance(body, dict) else {}
+
+ try:
+ from shield.core.feature_flags.evaluator import FlagEvaluator
+ from shield.core.feature_flags.models import EvaluationContext
+
+ ctx = EvaluationContext.model_validate({"key": "anonymous", **ctx_data})
+ engine = _engine(request)
+ # Gather all flags and segments from the engine for prerequisite resolution.
+ all_flags_list = await engine.list_flags()
+ all_flags = {f.key: f for f in all_flags_list}
+ segments_list = await engine.list_segments()
+ segments = {s.key: s for s in segments_list}
+
+ evaluator = FlagEvaluator(segments=segments)
+ result = evaluator.evaluate(flag, ctx, all_flags)
+ except Exception as exc:
+ return _err(f"Evaluation error: {exc}", 500)
+
+ return JSONResponse(
+ {
+ "flag_key": key,
+ "value": result.value,
+ "variation": result.variation,
+ "reason": result.reason.value,
+ "rule_id": result.rule_id,
+ "prerequisite_key": result.prerequisite_key,
+ "error_message": result.error_message,
+ }
+ )
+
+
+# ---------------------------------------------------------------------------
+# Segment endpoints
+# ---------------------------------------------------------------------------
+
+
+async def list_segments(request: Request) -> JSONResponse:
+ """GET /api/segments — list all segments."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ segments = await _engine(request).list_segments()
+ return JSONResponse([s.model_dump(mode="json") for s in segments])
+
+
+async def get_segment(request: Request) -> JSONResponse:
+ """GET /api/segments/{key} — get a single segment."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ segment = await _engine(request).get_segment(key)
+ if segment is None:
+ return _err(f"Segment '{key}' not found", 404)
+ return JSONResponse(segment.model_dump(mode="json"))
+
+
+async def create_segment(request: Request) -> JSONResponse:
+ """POST /api/segments — create a new segment."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ try:
+ body = await request.json()
+ except Exception:
+ return _err("Invalid JSON body")
+
+ try:
+ from shield.core.feature_flags.models import Segment
+
+ segment = Segment.model_validate(body)
+ except Exception as exc:
+ return _err(f"Invalid segment definition: {exc}")
+
+ existing = await _engine(request).get_segment(segment.key)
+ if existing is not None:
+ return _err(f"Segment '{segment.key}' already exists. Use PUT to update.", 409)
+
+ await _engine(request).save_segment(segment)
+ return JSONResponse(segment.model_dump(mode="json"), status_code=201)
+
+
+async def update_segment(request: Request) -> JSONResponse:
+ """PUT /api/segments/{key} — replace a segment (full update)."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ try:
+ body = await request.json()
+ except Exception:
+ return _err("Invalid JSON body")
+
+ if isinstance(body, dict) and body.get("key", key) != key:
+ return _err("Segment key in URL and body must match")
+
+ if isinstance(body, dict):
+ body["key"] = key
+
+ try:
+ from shield.core.feature_flags.models import Segment
+
+ segment = Segment.model_validate(body)
+ except Exception as exc:
+ return _err(f"Invalid segment definition: {exc}")
+
+ await _engine(request).save_segment(segment)
+ return JSONResponse(segment.model_dump(mode="json"))
+
+
+async def delete_segment(request: Request) -> JSONResponse:
+ """DELETE /api/segments/{key} — delete a segment."""
+ if not _flag_models_available():
+ return _flags_not_installed()
+ key = request.path_params["key"]
+ existing = await _engine(request).get_segment(key)
+ if existing is None:
+ return _err(f"Segment '{key}' not found", 404)
+ await _engine(request).delete_segment(key)
+ return JSONResponse({"ok": True, "deleted": key})
diff --git a/shield/admin/app.py b/shield/admin/app.py
index 39e773a..50af3d7 100644
--- a/shield/admin/app.py
+++ b/shield/admin/app.py
@@ -211,6 +211,70 @@ async def _logout(request: Request) -> Response:
return response
+def _flag_dashboard_modal_routes() -> list[Route]:
+ """Flag + segment modal routes that must be registered BEFORE the generic wildcard.
+
+ These must appear before ``Route("/modal/{action}/{path_key}", ...)`` in the
+ route list so Starlette's first-match routing picks the specific handler.
+ """
+ return [
+ Route("/modal/flag/create", _dash.modal_flag_create, methods=["GET"]),
+ Route("/modal/flag/{key}/eval", _dash.modal_flag_eval, methods=["GET"]),
+ Route("/modal/segment/create", _dash.modal_segment_create, methods=["GET"]),
+ Route("/modal/segment/{key}/view", _dash.modal_segment_view, methods=["GET"]),
+ Route("/modal/segment/{key}", _dash.modal_segment_detail, methods=["GET"]),
+ ]
+
+
+def _flag_dashboard_routes() -> list[Route]:
+ """Return the flag + segment dashboard UI routes for mounting in ShieldAdmin."""
+ return [
+ Route("/flags", _dash.flags_page, methods=["GET"]),
+ Route("/flags/rows", _dash.flags_rows_partial, methods=["GET"]),
+ Route("/flags/create", _dash.flag_create_form, methods=["POST"]),
+ Route("/flags/{key}", _dash.flag_detail_page, methods=["GET"]),
+ Route("/flags/{key}/settings/save", _dash.flag_settings_save, methods=["POST"]),
+ Route("/flags/{key}/variations/save", _dash.flag_variations_save, methods=["POST"]),
+ Route("/flags/{key}/targeting/save", _dash.flag_targeting_save, methods=["POST"]),
+ Route("/flags/{key}/prerequisites/save", _dash.flag_prerequisites_save, methods=["POST"]),
+ Route("/flags/{key}/targets/save", _dash.flag_targets_save, methods=["POST"]),
+ Route("/flags/{key}/enable", _dash.flag_enable, methods=["POST"]),
+ Route("/flags/{key}/disable", _dash.flag_disable, methods=["POST"]),
+ Route("/flags/{key}", _dash.flag_delete, methods=["DELETE"]),
+ Route("/flags/{key}/eval", _dash.flag_eval_form, methods=["POST"]),
+ Route("/segments", _dash.segments_page, methods=["GET"]),
+ Route("/segments/rows", _dash.segments_rows_partial, methods=["GET"]),
+ Route("/segments/create", _dash.segment_create_form, methods=["POST"]),
+ Route("/segments/{key}/rules/add", _dash.segment_rule_add, methods=["POST"]),
+ Route("/segments/{key}/rules/{rule_id}", _dash.segment_rule_delete, methods=["DELETE"]),
+ Route("/segments/{key}/save", _dash.segment_save_form, methods=["POST"]),
+ Route("/segments/{key}", _dash.modal_segment_detail, methods=["GET"]),
+ Route("/segments/{key}", _dash.segment_delete, methods=["DELETE"]),
+ ]
+
+
+def _flag_routes() -> list[Route]:
+ """Return the flag + segment API routes for mounting in ShieldAdmin."""
+ return [
+ # ── Flags CRUD ───────────────────────────────────────────────
+ Route("/api/flags", _api.list_flags, methods=["GET"]),
+ Route("/api/flags", _api.create_flag, methods=["POST"]),
+ Route("/api/flags/{key}", _api.get_flag, methods=["GET"]),
+ Route("/api/flags/{key}", _api.update_flag, methods=["PUT"]),
+ Route("/api/flags/{key}", _api.patch_flag, methods=["PATCH"]),
+ Route("/api/flags/{key}", _api.delete_flag, methods=["DELETE"]),
+ Route("/api/flags/{key}/enable", _api.enable_flag, methods=["POST"]),
+ Route("/api/flags/{key}/disable", _api.disable_flag, methods=["POST"]),
+ Route("/api/flags/{key}/evaluate", _api.evaluate_flag, methods=["POST"]),
+ # ── Segments CRUD ────────────────────────────────────────────
+ Route("/api/segments", _api.list_segments, methods=["GET"]),
+ Route("/api/segments", _api.create_segment, methods=["POST"]),
+ Route("/api/segments/{key}", _api.get_segment, methods=["GET"]),
+ Route("/api/segments/{key}", _api.update_segment, methods=["PUT"]),
+ Route("/api/segments/{key}", _api.delete_segment, methods=["DELETE"]),
+ ]
+
+
def ShieldAdmin(
engine: ShieldEngine,
auth: AuthConfig = None,
@@ -218,6 +282,7 @@ def ShieldAdmin(
sdk_token_expiry: int = 31536000,
secret_key: str | None = None,
prefix: str = "/shield",
+ enable_flags: bool | None = None,
) -> ASGIApp:
"""Create the unified Shield admin ASGI app.
@@ -249,6 +314,13 @@ def ShieldAdmin(
prefix:
URL prefix at which the admin app is mounted. Must match the path
passed to ``app.mount()``. Used to build correct redirects.
+ enable_flags:
+ When ``True``, mount the feature flag and segment dashboard UI and
+ REST API endpoints (``/flags/*``, ``/api/flags/*``, ``/api/segments/*``).
+ Requires ``engine.use_openfeature()`` to have been called and
+ ``api-shield[flags]`` to be installed.
+ When ``None`` (default), auto-detected: flags are enabled when
+ ``engine.use_openfeature()`` has been called.
Returns
-------
@@ -257,6 +329,10 @@ def ShieldAdmin(
"""
import base64
+ # Auto-detect flags: enabled when engine.use_openfeature() has been called.
+ if enable_flags is None:
+ enable_flags = getattr(engine, "_flag_client", None) is not None
+
templates = Jinja2Templates(directory=str(_TEMPLATES_DIR))
templates.env.filters["encode_path"] = lambda p: (
base64.urlsafe_b64encode(p.encode()).decode().rstrip("=")
@@ -286,6 +362,7 @@ def _clean_entry_path(entry: object) -> str:
templates.env.filters["clean_path"] = _clean_path
templates.env.filters["clean_entry_path"] = _clean_entry_path
+ templates.env.globals["flags_enabled"] = enable_flags
try:
version = importlib.metadata.version("api-shield")
@@ -316,6 +393,8 @@ def _clean_entry_path(entry: object) -> str:
Route("/modal/global-rl/delete", _dash.modal_global_rl_delete),
Route("/modal/global-rl/reset", _dash.modal_global_rl_reset),
Route("/modal/env/{path_key}", _dash.modal_env_gate),
+ # Flag/segment modals must come before the generic wildcard below.
+ *(_flag_dashboard_modal_routes() if enable_flags else []),
Route("/modal/{action}/{path_key}", _dash.action_modal),
Route(
"/global-maintenance/enable",
@@ -422,6 +501,9 @@ def _clean_entry_path(entry: object) -> str:
Route("/api/sdk/audit", _api.sdk_audit, methods=["POST"]),
# ── Service discovery ────────────────────────────────────────
Route("/api/services", _api.list_services, methods=["GET"]),
+ # ── Feature flags (mounted only when enable_flags=True) ──────
+ *(_flag_dashboard_routes() if enable_flags else []),
+ *(_flag_routes() if enable_flags else []),
],
)
@@ -432,6 +514,7 @@ def _clean_entry_path(entry: object) -> str:
starlette_app.state.version = version
starlette_app.state.token_manager = token_manager
starlette_app.state.auth_backend = auth_backend
+ starlette_app.state.flags_enabled = enable_flags
# Wrap with auth middleware.
return _AuthMiddleware(starlette_app, token_manager=token_manager, auth_backend=auth_backend)
diff --git a/shield/cli/client.py b/shield/cli/client.py
index b3025a7..23e5691 100644
--- a/shield/cli/client.py
+++ b/shield/cli/client.py
@@ -365,6 +365,102 @@ async def disable_global_rate_limit(self) -> dict[str, Any]:
resp = await c.post("/api/global-rate-limit/disable")
return cast(dict[str, Any], self._check(resp))
+ # ── Feature flags ─────────────────────────────────────────────────
+
+ async def list_flags(self) -> list[dict[str, Any]]:
+ """GET /api/flags — list all feature flags."""
+ async with self._make_client() as c:
+ resp = await c.get("/api/flags")
+ return cast(list[dict[str, Any]], self._check(resp))
+
+ async def get_flag(self, key: str) -> dict[str, Any]:
+ """GET /api/flags/{key} — get a single feature flag."""
+ async with self._make_client() as c:
+ resp = await c.get(f"/api/flags/{key}")
+ return cast(dict[str, Any], self._check(resp))
+
+ async def create_flag(self, flag_data: dict[str, Any]) -> dict[str, Any]:
+ """POST /api/flags — create a new feature flag."""
+ async with self._make_client() as c:
+ resp = await c.post("/api/flags", json=flag_data)
+ return cast(dict[str, Any], self._check(resp))
+
+ async def update_flag(self, key: str, flag_data: dict[str, Any]) -> dict[str, Any]:
+ """PUT /api/flags/{key} — replace a feature flag."""
+ async with self._make_client() as c:
+ resp = await c.put(f"/api/flags/{key}", json=flag_data)
+ return cast(dict[str, Any], self._check(resp))
+
+ async def patch_flag(self, key: str, patch: dict[str, Any]) -> dict[str, Any]:
+ """PATCH /api/flags/{key} — partial update."""
+ async with self._make_client() as c:
+ resp = await c.patch(f"/api/flags/{key}", json=patch)
+ return cast(dict[str, Any], self._check(resp))
+
+ async def enable_flag(self, key: str) -> dict[str, Any]:
+ """POST /api/flags/{key}/enable — enable a feature flag."""
+ async with self._make_client() as c:
+ resp = await c.post(f"/api/flags/{key}/enable")
+ return cast(dict[str, Any], self._check(resp))
+
+ async def disable_flag(self, key: str) -> dict[str, Any]:
+ """POST /api/flags/{key}/disable — disable a feature flag."""
+ async with self._make_client() as c:
+ resp = await c.post(f"/api/flags/{key}/disable")
+ return cast(dict[str, Any], self._check(resp))
+
+ async def delete_flag(self, key: str) -> dict[str, Any]:
+ """DELETE /api/flags/{key} — delete a feature flag."""
+ async with self._make_client() as c:
+ resp = await c.delete(f"/api/flags/{key}")
+ return cast(dict[str, Any], self._check(resp))
+
+ async def evaluate_flag(
+ self,
+ key: str,
+ context: dict[str, Any],
+ default: Any = None,
+ ) -> dict[str, Any]:
+ """POST /api/flags/{key}/evaluate — evaluate a flag for a context."""
+ async with self._make_client() as c:
+ resp = await c.post(
+ f"/api/flags/{key}/evaluate",
+ json={"context": context, "default": default},
+ )
+ return cast(dict[str, Any], self._check(resp))
+
+ # ── Segments ──────────────────────────────────────────────────────
+
+ async def list_segments(self) -> list[dict[str, Any]]:
+ """GET /api/segments — list all segments."""
+ async with self._make_client() as c:
+ resp = await c.get("/api/segments")
+ return cast(list[dict[str, Any]], self._check(resp))
+
+ async def get_segment(self, key: str) -> dict[str, Any]:
+ """GET /api/segments/{key} — get a single segment."""
+ async with self._make_client() as c:
+ resp = await c.get(f"/api/segments/{key}")
+ return cast(dict[str, Any], self._check(resp))
+
+ async def create_segment(self, segment_data: dict[str, Any]) -> dict[str, Any]:
+ """POST /api/segments — create a new segment."""
+ async with self._make_client() as c:
+ resp = await c.post("/api/segments", json=segment_data)
+ return cast(dict[str, Any], self._check(resp))
+
+ async def update_segment(self, key: str, segment_data: dict[str, Any]) -> dict[str, Any]:
+ """PUT /api/segments/{key} — replace a segment."""
+ async with self._make_client() as c:
+ resp = await c.put(f"/api/segments/{key}", json=segment_data)
+ return cast(dict[str, Any], self._check(resp))
+
+ async def delete_segment(self, key: str) -> dict[str, Any]:
+ """DELETE /api/segments/{key} — delete a segment."""
+ async with self._make_client() as c:
+ resp = await c.delete(f"/api/segments/{key}")
+ return cast(dict[str, Any], self._check(resp))
+
def make_client(
transport: httpx.AsyncBaseTransport | None = None,
diff --git a/shield/cli/main.py b/shield/cli/main.py
index fb3db91..f7ee7c0 100644
--- a/shield/cli/main.py
+++ b/shield/cli/main.py
@@ -1340,5 +1340,967 @@ async def _run_grl_disable() -> None:
_run(_run_grl_disable)
+# ---------------------------------------------------------------------------
+# Feature flags command group (shield flags ...)
+# ---------------------------------------------------------------------------
+
+_FLAG_TYPE_COLOURS = {
+ "boolean": "green",
+ "string": "cyan",
+ "integer": "blue",
+ "float": "blue",
+ "json": "magenta",
+}
+
+flags_app = typer.Typer(
+ name="flags",
+ help="Manage feature flags.",
+ no_args_is_help=True,
+)
+cli.add_typer(flags_app, name="flags")
+
+
+def _flag_status_colour(enabled: bool) -> str:
+ return "green" if enabled else "dim"
+
+
+def _print_flags_table(flags: list[dict[str, Any]]) -> None:
+ tbl = Table(box=box.SIMPLE_HEAD, show_edge=False, pad_edge=False)
+ tbl.add_column("Key", style="bold cyan", no_wrap=True)
+ tbl.add_column("Type", style="white")
+ tbl.add_column("Status", style="white")
+ tbl.add_column("Variations", style="dim")
+ tbl.add_column("Fallthrough", style="dim")
+ for f in flags:
+ enabled = f.get("enabled", True)
+ status_text = "[green]enabled[/green]" if enabled else "[dim]disabled[/dim]"
+ ftype = f.get("type", "")
+ colour = _FLAG_TYPE_COLOURS.get(ftype, "white")
+ variations = ", ".join(v["name"] for v in f.get("variations", []))
+ fallthrough = f.get("fallthrough", "")
+ if isinstance(fallthrough, list):
+ fallthrough = "rollout"
+ tbl.add_row(
+ f.get("key", ""),
+ f"[{colour}]{ftype}[/{colour}]",
+ status_text,
+ variations,
+ str(fallthrough),
+ )
+ console.print(tbl)
+
+
+@flags_app.command("list")
+def flags_list(
+ type: str = typer.Option("", "--type", "-t", help="Filter by flag type (boolean, string, …)"),
+ enabled: str = typer.Option("", "--status", "-s", help="Filter by status: enabled or disabled"),
+) -> None:
+ """List all feature flags."""
+
+ async def _run_flags_list() -> None:
+ flags = await make_client().list_flags()
+ if type:
+ flags = [f for f in flags if f.get("type") == type]
+ if enabled == "enabled":
+ flags = [f for f in flags if f.get("enabled", True)]
+ elif enabled == "disabled":
+ flags = [f for f in flags if not f.get("enabled", True)]
+ if not flags:
+ console.print("[dim]No flags found.[/dim]")
+ return
+ _print_flags_table(flags)
+ console.print(f"[dim]{len(flags)} flag(s)[/dim]")
+
+ _run(_run_flags_list)
+
+
+@flags_app.command("get")
+def flags_get(key: str = typer.Argument(..., help="Flag key")) -> None:
+ """Show details for a single feature flag."""
+
+ async def _run_flags_get() -> None:
+ flag = await make_client().get_flag(key)
+ console.print(f"[bold cyan]{flag['key']}[/bold cyan] [dim]{flag.get('name', '')}[/dim]")
+ ftype = flag.get("type", "")
+ colour = _FLAG_TYPE_COLOURS.get(ftype, "white")
+ enabled = flag.get("enabled", True)
+ status_text = "[green]enabled[/green]" if enabled else "[dim]disabled[/dim]"
+ console.print(f" Type: [{colour}]{ftype}[/{colour}]")
+ console.print(f" Status: {status_text}")
+ console.print(f" Off variation: [dim]{flag.get('off_variation', '')}[/dim]")
+ fallthrough = flag.get("fallthrough", "")
+ if isinstance(fallthrough, list):
+ parts = [f"{rv['variation']}:{rv['weight'] // 1000}%" for rv in fallthrough]
+ console.print(f" Fallthrough: [dim]{', '.join(parts)}[/dim]")
+ else:
+ console.print(f" Fallthrough: [dim]{fallthrough}[/dim]")
+ # Variations
+ console.print(" Variations:")
+ for v in flag.get("variations", []):
+ console.print(f" • [bold]{v['name']}[/bold] = {v['value']!r}")
+ # Rules
+ rules = flag.get("rules") or []
+ if rules:
+ console.print(f" Rules: [dim]{len(rules)} targeting rule(s)[/dim]")
+ # Prerequisites
+ prereqs = flag.get("prerequisites") or []
+ if prereqs:
+ console.print(" Prerequisites:")
+ for p in prereqs:
+ console.print(
+ f" • [cyan]{p['flag_key']}[/cyan] must be [bold]{p['variation']}[/bold]"
+ )
+
+ _run(_run_flags_get)
+
+
+@flags_app.command("create")
+def flags_create(
+ key: str = typer.Argument(..., help="Unique flag key (e.g. new_checkout)"),
+ name: str = typer.Option(..., "--name", "-n", help="Human-readable name"),
+ type: str = typer.Option(
+ "boolean", "--type", "-t", help="Flag type: boolean, string, integer, float, json"
+ ),
+ description: str = typer.Option("", "--description", "-d", help="Optional description"),
+) -> None:
+ """Create a new boolean feature flag with on/off variations.
+
+ For other types or advanced configuration, use the dashboard or the API
+ directly. The flag is created enabled with fallthrough=off.
+
+ \b
+ shield flags create new_checkout --name "New Checkout Flow"
+ shield flags create dark_mode --name "Dark Mode" --type boolean
+ """
+
+ async def _run_flags_create() -> None:
+ flag_type = type.lower()
+ # Build default on/off variations based on type.
+ if flag_type == "boolean":
+ variations = [{"name": "on", "value": True}, {"name": "off", "value": False}]
+ off_variation = "off"
+ fallthrough = "off"
+ elif flag_type == "string":
+ variations = [
+ {"name": "control", "value": "control"},
+ {"name": "treatment", "value": "treatment"},
+ ]
+ off_variation = "control"
+ fallthrough = "control"
+ elif flag_type in ("integer", "float"):
+ variations = [{"name": "off", "value": 0}, {"name": "on", "value": 1}]
+ off_variation = "off"
+ fallthrough = "off"
+ elif flag_type == "json":
+ variations = [{"name": "off", "value": {}}, {"name": "on", "value": {}}]
+ off_variation = "off"
+ fallthrough = "off"
+ else:
+ err_console.print(
+ f"[red]Error:[/red] Unknown type {type!r}. "
+ "Use boolean, string, integer, float, or json."
+ )
+ raise typer.Exit(code=1)
+
+ flag_data = {
+ "key": key,
+ "name": name,
+ "type": flag_type,
+ "description": description,
+ "variations": variations,
+ "off_variation": off_variation,
+ "fallthrough": fallthrough,
+ "enabled": True,
+ }
+ result = await make_client().create_flag(flag_data)
+ console.print(f"[green]✓[/green] Flag [bold cyan]{result['key']}[/bold cyan] created.")
+
+ _run(_run_flags_create)
+
+
+@flags_app.command("enable")
+def flags_enable(key: str = typer.Argument(..., help="Flag key")) -> None:
+ """Enable a feature flag."""
+
+ async def _run_flags_enable() -> None:
+ result = await make_client().enable_flag(key)
+ console.print(f"[green]✓[/green] Flag [bold cyan]{result['key']}[/bold cyan] enabled.")
+
+ _run(_run_flags_enable)
+
+
+@flags_app.command("disable")
+def flags_disable(key: str = typer.Argument(..., help="Flag key")) -> None:
+ """Disable a feature flag (serves the off variation to all users)."""
+
+ async def _run_flags_disable() -> None:
+ result = await make_client().disable_flag(key)
+ console.print(f"[dim]✓ Flag {result['key']} disabled.[/dim]")
+
+ _run(_run_flags_disable)
+
+
+@flags_app.command("delete")
+def flags_delete(
+ key: str = typer.Argument(..., help="Flag key"),
+ yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
+) -> None:
+ """Permanently delete a feature flag."""
+ if not yes:
+ typer.confirm(f"Delete flag '{key}'? This cannot be undone.", abort=True)
+
+ async def _run_flags_delete() -> None:
+ result = await make_client().delete_flag(key)
+ console.print(f"[green]✓[/green] Flag [bold]{result['deleted']}[/bold] deleted.")
+
+ _run(_run_flags_delete)
+
+
+@flags_app.command("eval")
+def flags_eval(
+ key: str = typer.Argument(..., help="Flag key"),
+ ctx_key: str = typer.Option("anonymous", "--key", "-k", help="Context key (user ID)"),
+ kind: str = typer.Option("user", "--kind", help="Context kind"),
+ attr: list[str] = typer.Option([], "--attr", "-a", help="Attribute as key=value (repeatable)"),
+) -> None:
+ """Evaluate a feature flag for a given context (debug tool).
+
+ \b
+ shield flags eval new_checkout --key user_123 --attr role=admin --attr plan=pro
+ """
+
+ async def _run_flags_eval() -> None:
+ attributes: dict[str, str] = {}
+ for a in attr:
+ if "=" not in a:
+ err_console.print(f"[red]Error:[/red] Attribute must be key=value, got: {a!r}")
+ raise typer.Exit(code=1)
+ k, _, v = a.partition("=")
+ attributes[k.strip()] = v.strip()
+
+ context = {"key": ctx_key, "kind": kind, "attributes": attributes}
+ result = await make_client().evaluate_flag(key, context)
+
+ value = result.get("value")
+ variation = result.get("variation", "")
+ reason = result.get("reason", "")
+ rule_id = result.get("rule_id")
+
+ tbl = Table(box=box.SIMPLE_HEAD, show_edge=False, pad_edge=False, show_header=False)
+ tbl.add_column("Field", style="dim", no_wrap=True)
+ tbl.add_column("Value", style="bold")
+ tbl.add_row("value", str(value))
+ tbl.add_row("variation", variation or "—")
+ tbl.add_row("reason", reason)
+ if rule_id:
+ tbl.add_row("rule_id", rule_id)
+ prereq = result.get("prerequisite_key")
+ if prereq:
+ tbl.add_row("prerequisite_key", prereq)
+ err_msg = result.get("error_message")
+ if err_msg:
+ tbl.add_row("error", f"[red]{err_msg}[/red]")
+ console.print(tbl)
+
+ _run(_run_flags_eval)
+
+
+@flags_app.command("edit")
+def flags_edit(
+ key: str = typer.Argument(..., help="Flag key"),
+ name: str | None = typer.Option(None, "--name", "-n", help="New display name"),
+ description: str | None = typer.Option(None, "--description", "-d", help="New description"),
+ off_variation: str | None = typer.Option(
+ None, "--off-variation", help="Variation served when flag is disabled"
+ ),
+ fallthrough: str | None = typer.Option(
+ None, "--fallthrough", help="Default variation when no rule matches"
+ ),
+) -> None:
+ """Patch a feature flag (partial update — only provided fields are changed).
+
+ \b
+ shield flags edit dark_mode --name "Dark Mode v2"
+ shield flags edit dark_mode --off-variation off --fallthrough control
+ """
+
+ async def _run_flags_edit() -> None:
+ patch: dict[str, Any] = {}
+ if name is not None:
+ patch["name"] = name
+ if description is not None:
+ patch["description"] = description
+ if off_variation is not None:
+ patch["off_variation"] = off_variation
+ if fallthrough is not None:
+ patch["fallthrough"] = fallthrough
+ if not patch:
+ err_console.print("[yellow]Nothing to update — provide at least one option.[/yellow]")
+ raise typer.Exit(1)
+ result = await make_client().patch_flag(key, patch)
+ console.print(f"[green]✓[/green] Flag [bold cyan]{result['key']}[/bold cyan] updated.")
+ tbl = Table(box=box.SIMPLE_HEAD, show_edge=False, pad_edge=False, show_header=False)
+ tbl.add_column("Field", style="dim", no_wrap=True)
+ tbl.add_column("Value", style="bold")
+ for field in ("name", "description", "off_variation", "fallthrough"):
+ if field in patch:
+ val = result.get(field)
+ tbl.add_row(field, str(val) if val is not None else "—")
+ console.print(tbl)
+
+ _run(_run_flags_edit)
+
+
+@flags_app.command("variations")
+def flags_variations(key: str = typer.Argument(..., help="Flag key")) -> None:
+ """List variations for a feature flag."""
+
+ async def _run_flags_variations() -> None:
+ flag = await make_client().get_flag(key)
+ variations = flag.get("variations") or []
+ if not variations:
+ console.print(f"[dim]No variations for flag '{key}'.[/dim]")
+ return
+ tbl = Table(box=box.ROUNDED, show_header=True, header_style="bold")
+ tbl.add_column("Name", style="bold cyan", no_wrap=True)
+ tbl.add_column("Value", style="white")
+ tbl.add_column("Description", style="dim")
+ tbl.add_column("Role", style="dim")
+ off_var = flag.get("off_variation", "")
+ fallthrough = flag.get("fallthrough")
+ for v in variations:
+ vname = v.get("name", "")
+ role = ""
+ if vname == off_var:
+ role = "[slate]off[/slate]"
+ elif isinstance(fallthrough, str) and vname == fallthrough:
+ role = "[magenta]fallthrough[/magenta]"
+ tbl.add_row(vname, str(v.get("value", "")), v.get("description") or "—", role)
+ console.print(f"[bold cyan]{flag['key']}[/bold cyan] [dim]{flag.get('type', '')}[/dim]")
+ console.print(tbl)
+
+ _run(_run_flags_variations)
+
+
+@flags_app.command("targeting")
+def flags_targeting(key: str = typer.Argument(..., help="Flag key")) -> None:
+ """Show targeting rules for a feature flag (read-only view)."""
+
+ async def _run_flags_targeting() -> None:
+ flag = await make_client().get_flag(key)
+ rules = flag.get("rules") or []
+
+ off_var = flag.get("off_variation", "—")
+ ft = flag.get("fallthrough", "—")
+ console.print(
+ f"[bold cyan]{flag['key']}[/bold cyan]"
+ f" off=[cyan]{off_var}[/cyan]"
+ f" fallthrough=[cyan]{ft}[/cyan]"
+ )
+
+ if not rules:
+ console.print("[dim]No targeting rules.[/dim]")
+ return
+
+ for i, rule in enumerate(rules):
+ desc = rule.get("description") or ""
+ variation = rule.get("variation") or "—"
+ clauses = rule.get("clauses") or []
+ console.print(
+ f"\n [bold]Rule {i + 1}[/bold]"
+ + (f" — {desc}" if desc else "")
+ + f" → [green]{variation}[/green]"
+ )
+ console.print(f" [dim]id: {rule.get('id', '')}[/dim]")
+ for clause in clauses:
+ attr = clause.get("attribute", "")
+ op = clause.get("operator", "")
+ vals = clause.get("values") or []
+ negate = clause.get("negate", False)
+ neg_str = "[dim]NOT[/dim] " if negate else ""
+ vals_str = ", ".join(str(v) for v in vals)
+ console.print(f" {neg_str}[cyan]{attr}[/cyan] [dim]{op}[/dim] {vals_str}")
+
+ _run(_run_flags_targeting)
+
+
+@flags_app.command("add-rule")
+def flags_add_rule(
+ key: str = typer.Argument(..., help="Flag key"),
+ variation: str = typer.Option(
+ ..., "--variation", "-v", help="Variation to serve when rule matches"
+ ),
+ segment: str | None = typer.Option(
+ None, "--segment", "-s", help="Segment key (adds an in_segment clause)"
+ ),
+ attribute: str | None = typer.Option(
+ None, "--attribute", "-a", help="Attribute name for a custom clause"
+ ),
+ operator: str = typer.Option(
+ "is", "--operator", "-o", help="Operator (e.g. is, in_segment, contains)"
+ ),
+ values: str | None = typer.Option(None, "--values", help="Comma-separated clause values"),
+ description: str = typer.Option("", "--description", "-d", help="Optional rule description"),
+ negate: bool = typer.Option(False, "--negate", help="Negate the clause result"),
+) -> None:
+ """Add a targeting rule to a feature flag.
+
+ \b
+ Segment-based rule (most common):
+ shield flags add-rule my-flag --variation on --segment beta-users
+
+ Custom attribute rule:
+ shield flags add-rule my-flag --variation on \
+ --attribute plan --operator is --values pro,enterprise
+ """
+ if segment is None and attribute is None:
+ console.print("[red]Error:[/red] provide --segment or --attribute.")
+ raise typer.Exit(1)
+ if segment is not None and attribute is not None:
+ console.print("[red]Error:[/red] --segment and --attribute are mutually exclusive.")
+ raise typer.Exit(1)
+
+ async def _run_add_rule() -> None:
+ client = make_client()
+ flag = await client.get_flag(key)
+ rules = list(flag.get("rules") or [])
+
+ if segment is not None:
+ clause = {
+ "attribute": "key",
+ "operator": "in_segment",
+ "values": [segment],
+ "negate": negate,
+ }
+ else:
+ raw_vals: list[Any] = [v.strip() for v in (values or "").split(",") if v.strip()]
+ clause = {
+ "attribute": attribute,
+ "operator": operator,
+ "values": raw_vals,
+ "negate": negate,
+ }
+
+ import uuid as _uuid
+
+ new_rule: dict[str, Any] = {
+ "id": str(_uuid.uuid4()),
+ "description": description,
+ "clauses": [clause],
+ "variation": variation,
+ }
+ rules.append(new_rule)
+ await client.patch_flag(key, {"rules": rules})
+ clause_summary = (
+ f"in_segment [cyan]{segment}[/cyan]"
+ if segment is not None
+ else f"[cyan]{attribute}[/cyan] [dim]{operator}[/dim] {values}"
+ )
+ console.print(
+ f"[green]✓[/green] Rule added to [bold cyan]{key}[/bold cyan]: "
+ f"{clause_summary} → [green]{variation}[/green]"
+ )
+ console.print(f" [dim]id: {new_rule['id']}[/dim]")
+
+ _run(_run_add_rule)
+
+
+@flags_app.command("remove-rule")
+def flags_remove_rule(
+ key: str = typer.Argument(..., help="Flag key"),
+ rule_id: str = typer.Option(..., "--rule-id", "-r", help="Rule ID to remove"),
+) -> None:
+ """Remove a targeting rule from a feature flag by its ID.
+
+ \b
+ shield flags remove-rule my-flag --rule-id
+
+ Use 'shield flags targeting my-flag' to list rule IDs.
+ """
+
+ async def _run_remove_rule() -> None:
+ client = make_client()
+ flag = await client.get_flag(key)
+ rules = list(flag.get("rules") or [])
+ original_len = len(rules)
+ rules = [r for r in rules if r.get("id") != rule_id]
+ if len(rules) == original_len:
+ console.print(f"[red]Error:[/red] no rule with id '{rule_id}' found on flag '{key}'.")
+ raise typer.Exit(1)
+ await client.patch_flag(key, {"rules": rules})
+ console.print(
+ f"[green]✓[/green] Rule [dim]{rule_id}[/dim] removed from [bold cyan]{key}[/bold cyan]."
+ )
+
+ _run(_run_remove_rule)
+
+
+# ---------------------------------------------------------------------------
+# Prerequisites commands (shield flags add-prereq / remove-prereq)
+# ---------------------------------------------------------------------------
+
+
+@flags_app.command("add-prereq")
+def flags_add_prereq(
+ key: str = typer.Argument(..., help="Flag key"),
+ prereq_flag: str = typer.Option(..., "--flag", "-f", help="Prerequisite flag key"),
+ variation: str = typer.Option(
+ ..., "--variation", "-v", help="Variation the prerequisite flag must return"
+ ),
+) -> None:
+ """Add a prerequisite flag to a feature flag.
+
+ \b
+ shield flags add-prereq my-flag --flag auth-flag --variation on
+
+ The prerequisite flag must evaluate to the given variation before this
+ flag's rules run. If it doesn't, this flag serves its off_variation.
+ """
+
+ async def _run_add_prereq() -> None:
+ client = make_client()
+ flag = await client.get_flag(key)
+ if flag["key"] == prereq_flag:
+ console.print("[red]Error:[/red] a flag cannot be its own prerequisite.")
+ raise typer.Exit(1)
+ prereqs = list(flag.get("prerequisites") or [])
+ # avoid duplicates
+ for p in prereqs:
+ if p.get("flag_key") == prereq_flag:
+ console.print(
+ f"[yellow]Warning:[/yellow] prerequisite [cyan]{prereq_flag}[/cyan]"
+ " already exists. Updating variation."
+ )
+ p["variation"] = variation
+ await client.patch_flag(key, {"prerequisites": prereqs})
+ console.print(
+ f"[green]✓[/green] Prerequisite [cyan]{prereq_flag}[/cyan]"
+ f" updated → must be [green]{variation}[/green]."
+ )
+ return
+ prereqs.append({"flag_key": prereq_flag, "variation": variation})
+ await client.patch_flag(key, {"prerequisites": prereqs})
+ console.print(
+ f"[green]✓[/green] Prerequisite [cyan]{prereq_flag}[/cyan]"
+ f" added to [bold cyan]{key}[/bold cyan]:"
+ f" must be [green]{variation}[/green]."
+ )
+
+ _run(_run_add_prereq)
+
+
+@flags_app.command("remove-prereq")
+def flags_remove_prereq(
+ key: str = typer.Argument(..., help="Flag key"),
+ prereq_flag: str = typer.Option(..., "--flag", "-f", help="Prerequisite flag key to remove"),
+) -> None:
+ """Remove a prerequisite from a feature flag.
+
+ \b
+ shield flags remove-prereq my-flag --flag auth-flag
+ """
+
+ async def _run_remove_prereq() -> None:
+ client = make_client()
+ flag = await client.get_flag(key)
+ prereqs = list(flag.get("prerequisites") or [])
+ original_len = len(prereqs)
+ prereqs = [p for p in prereqs if p.get("flag_key") != prereq_flag]
+ if len(prereqs) == original_len:
+ console.print(
+ f"[red]Error:[/red] prerequisite [cyan]{prereq_flag}[/cyan]"
+ f" not found on flag [cyan]{key}[/cyan]."
+ )
+ raise typer.Exit(1)
+ await client.patch_flag(key, {"prerequisites": prereqs})
+ console.print(
+ f"[green]✓[/green] Prerequisite [cyan]{prereq_flag}[/cyan]"
+ f" removed from [bold cyan]{key}[/bold cyan]."
+ )
+
+ _run(_run_remove_prereq)
+
+
+# ---------------------------------------------------------------------------
+# Individual targets commands (shield flags target / untarget)
+# ---------------------------------------------------------------------------
+
+
+@flags_app.command("target")
+def flags_target(
+ key: str = typer.Argument(..., help="Flag key"),
+ variation: str = typer.Option(
+ ..., "--variation", "-v", help="Variation to serve to the context keys"
+ ),
+ context_keys: str = typer.Option(
+ ..., "--keys", "-k", help="Comma-separated context keys to pin"
+ ),
+) -> None:
+ """Pin context keys to always receive a specific variation.
+
+ \b
+ shield flags target my-flag --variation on --keys user_123,user_456
+
+ Individual targets are evaluated before rules — highest priority targeting.
+ """
+
+ async def _run_target() -> None:
+ client = make_client()
+ flag = await client.get_flag(key)
+ variation_names = [v["name"] for v in (flag.get("variations") or [])]
+ if variation not in variation_names:
+ console.print(
+ f"[red]Error:[/red] variation [cyan]{variation}[/cyan] not found."
+ f" Available: {', '.join(variation_names)}"
+ )
+ raise typer.Exit(1)
+ new_keys = [k.strip() for k in context_keys.split(",") if k.strip()]
+ targets: dict[str, Any] = dict(flag.get("targets") or {})
+ existing = list(targets.get(variation, []))
+ added = [k for k in new_keys if k not in existing]
+ existing.extend(added)
+ targets[variation] = existing
+ await client.patch_flag(key, {"targets": targets})
+ console.print(
+ f"[green]✓[/green] Added {len(added)} key(s)"
+ f" to [bold cyan]{key}[/bold cyan] → [green]{variation}[/green]."
+ )
+
+ _run(_run_target)
+
+
+@flags_app.command("untarget")
+def flags_untarget(
+ key: str = typer.Argument(..., help="Flag key"),
+ variation: str = typer.Option(
+ ..., "--variation", "-v", help="Variation to remove context keys from"
+ ),
+ context_keys: str = typer.Option(
+ ..., "--keys", "-k", help="Comma-separated context keys to unpin"
+ ),
+) -> None:
+ """Remove context keys from individual targeting.
+
+ \b
+ shield flags untarget my-flag --variation on --keys user_123
+ """
+
+ async def _run_untarget() -> None:
+ client = make_client()
+ flag = await client.get_flag(key)
+ remove_keys = {k.strip() for k in context_keys.split(",") if k.strip()}
+ targets: dict[str, Any] = dict(flag.get("targets") or {})
+ existing = list(targets.get(variation, []))
+ if not existing:
+ console.print(
+ f"[yellow]Warning:[/yellow] no targets for variation [cyan]{variation}[/cyan]."
+ )
+ raise typer.Exit(1)
+ updated = [k for k in existing if k not in remove_keys]
+ if updated:
+ targets[variation] = updated
+ else:
+ targets.pop(variation, None)
+ await client.patch_flag(key, {"targets": targets})
+ removed = len(existing) - len(updated)
+ console.print(
+ f"[green]✓[/green] Removed {removed} key(s)"
+ f" from [bold cyan]{key}[/bold cyan] → [cyan]{variation}[/cyan]."
+ )
+
+ _run(_run_untarget)
+
+
+# ---------------------------------------------------------------------------
+# Segments command group (shield segments ...)
+# ---------------------------------------------------------------------------
+
+segments_app = typer.Typer(
+ name="segments",
+ help="Manage targeting segments.",
+ no_args_is_help=True,
+)
+cli.add_typer(segments_app, name="segments")
+cli.add_typer(segments_app, name="seg")
+
+
+def _print_segments_table(segments: list[dict[str, Any]]) -> None:
+ tbl = Table(box=box.SIMPLE_HEAD, show_edge=False, pad_edge=False)
+ tbl.add_column("Key", style="bold cyan", no_wrap=True)
+ tbl.add_column("Name", style="white")
+ tbl.add_column("Included", style="green")
+ tbl.add_column("Excluded", style="red")
+ tbl.add_column("Rules", style="dim")
+ for s in segments:
+ included = s.get("included") or []
+ excluded = s.get("excluded") or []
+ rules = s.get("rules") or []
+ tbl.add_row(
+ s.get("key", ""),
+ s.get("name", ""),
+ str(len(included)),
+ str(len(excluded)),
+ str(len(rules)),
+ )
+ console.print(tbl)
+
+
+@segments_app.command("list")
+def segments_list() -> None:
+ """List all targeting segments."""
+
+ async def _run_segments_list() -> None:
+ segments = await make_client().list_segments()
+ if not segments:
+ console.print("[dim]No segments found.[/dim]")
+ return
+ _print_segments_table(segments)
+ console.print(f"[dim]{len(segments)} segment(s)[/dim]")
+
+ _run(_run_segments_list)
+
+
+@segments_app.command("get")
+def segments_get(key: str = typer.Argument(..., help="Segment key")) -> None:
+ """Show details for a single segment."""
+
+ async def _run_segments_get() -> None:
+ seg = await make_client().get_segment(key)
+ console.print(f"[bold cyan]{seg['key']}[/bold cyan] [dim]{seg.get('name', '')}[/dim]")
+ included = seg.get("included") or []
+ excluded = seg.get("excluded") or []
+ rules = seg.get("rules") or []
+ if included:
+ console.print(
+ f" Included ({len(included)}): [green]{', '.join(included[:10])}[/green]"
+ + (" …" if len(included) > 10 else "")
+ )
+ if excluded:
+ console.print(
+ f" Excluded ({len(excluded)}): [red]{', '.join(excluded[:10])}[/red]"
+ + (" …" if len(excluded) > 10 else "")
+ )
+ if rules:
+ console.print(f" Rules: [dim]{len(rules)} targeting rule(s)[/dim]")
+ if not included and not excluded and not rules:
+ console.print(" [dim](empty segment)[/dim]")
+
+ _run(_run_segments_get)
+
+
+@segments_app.command("create")
+def segments_create(
+ key: str = typer.Argument(..., help="Unique segment key"),
+ name: str = typer.Option(..., "--name", "-n", help="Human-readable segment name"),
+ description: str = typer.Option("", "--description", "-d", help="Optional description"),
+) -> None:
+ """Create a new targeting segment.
+
+ \b
+ shield segments create beta_users --name "Beta Users"
+ """
+
+ async def _run_segments_create() -> None:
+ segment_data = {
+ "key": key,
+ "name": name,
+ "description": description,
+ "included": [],
+ "excluded": [],
+ "rules": [],
+ }
+ result = await make_client().create_segment(segment_data)
+ console.print(f"[green]✓[/green] Segment [bold cyan]{result['key']}[/bold cyan] created.")
+
+ _run(_run_segments_create)
+
+
+@segments_app.command("delete")
+def segments_delete(
+ key: str = typer.Argument(..., help="Segment key"),
+ yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"),
+) -> None:
+ """Permanently delete a targeting segment."""
+ if not yes:
+ typer.confirm(f"Delete segment '{key}'? This cannot be undone.", abort=True)
+
+ async def _run_segments_delete() -> None:
+ result = await make_client().delete_segment(key)
+ console.print(f"[green]✓[/green] Segment [bold]{result['deleted']}[/bold] deleted.")
+
+ _run(_run_segments_delete)
+
+
+@segments_app.command("include")
+def segments_include(
+ key: str = typer.Argument(..., help="Segment key"),
+ context_key: str = typer.Option(
+ ...,
+ "--context-key",
+ "-k",
+ help="Comma-separated context keys to add to the included list",
+ ),
+) -> None:
+ """Add context keys to the segment's included list.
+
+ \b
+ shield segments include beta_users --context-key user_123,user_456
+ """
+
+ async def _run_segments_include() -> None:
+ new_keys = [k.strip() for k in context_key.split(",") if k.strip()]
+ seg = await make_client().get_segment(key)
+ included = list(seg.get("included") or [])
+ added = [k for k in new_keys if k not in included]
+ included.extend(added)
+ seg["included"] = included
+ await make_client().update_segment(key, seg)
+ console.print(
+ f"[green]✓[/green] Added {len(added)} key(s) to [bold cyan]{key}[/bold cyan] "
+ f"included list."
+ )
+
+ _run(_run_segments_include)
+
+
+@segments_app.command("exclude")
+def segments_exclude(
+ key: str = typer.Argument(..., help="Segment key"),
+ context_key: str = typer.Option(
+ ...,
+ "--context-key",
+ "-k",
+ help="Comma-separated context keys to add to the excluded list",
+ ),
+) -> None:
+ """Add context keys to the segment's excluded list.
+
+ \b
+ shield segments exclude beta_users --context-key user_789
+ """
+
+ async def _run_segments_exclude() -> None:
+ new_keys = [k.strip() for k in context_key.split(",") if k.strip()]
+ seg = await make_client().get_segment(key)
+ excluded = list(seg.get("excluded") or [])
+ added = [k for k in new_keys if k not in excluded]
+ excluded.extend(added)
+ seg["excluded"] = excluded
+ await make_client().update_segment(key, seg)
+ console.print(
+ f"[green]✓[/green] Added {len(added)} key(s) to [bold cyan]{key}[/bold cyan] "
+ f"excluded list."
+ )
+
+ _run(_run_segments_exclude)
+
+
+@segments_app.command("add-rule")
+def segments_add_rule(
+ key: str = typer.Argument(..., help="Segment key"),
+ attribute: str = typer.Option(
+ ...,
+ "--attribute",
+ "-a",
+ help="Context attribute (e.g. plan, country)",
+ ),
+ operator: str = typer.Option(
+ "is",
+ "--operator",
+ "-o",
+ help="Operator (e.g. is, in, contains, in_segment)",
+ ),
+ values: str = typer.Option(
+ ...,
+ "--values",
+ "-V",
+ help="Comma-separated values to compare against",
+ ),
+ description: str = typer.Option("", "--description", "-d", help="Optional rule description"),
+ negate: bool = typer.Option(False, "--negate", help="Negate the clause result"),
+) -> None:
+ """Add an attribute-based targeting rule to a segment.
+
+ \b
+ Users matching ANY rule are included in the segment.
+ Multiple clauses within one rule are AND-ed together.
+
+ \b
+ Examples:
+ shield segments add-rule beta_users --attribute plan --operator in --values pro,enterprise
+ shield segments add-rule beta_users --attribute country --operator is --values GB
+ shield segments add-rule beta_users --attribute email --operator ends_with \\
+ --values @acme.com --description "Acme staff"
+ """
+
+ async def _run_add_rule() -> None:
+ import uuid as _uuid
+
+ client = make_client()
+ seg = await client.get_segment(key)
+ rules = list(seg.get("rules") or [])
+
+ # For segment operators the attribute defaults to "key"
+ attr = "key" if operator in ("in_segment", "not_in_segment") else attribute
+ raw_vals: list[Any] = [v.strip() for v in values.split(",") if v.strip()]
+ clause: dict[str, Any] = {
+ "attribute": attr,
+ "operator": operator,
+ "values": raw_vals,
+ "negate": negate,
+ }
+ new_rule: dict[str, Any] = {
+ "id": str(_uuid.uuid4()),
+ "clauses": [clause],
+ }
+ if description:
+ new_rule["description"] = description
+ rules.append(new_rule)
+ seg["rules"] = rules
+ await client.update_segment(key, seg)
+
+ clause_summary = f"[cyan]{attr}[/cyan] [dim]{operator}[/dim] {values}"
+ console.print(
+ f"[green]✓[/green] Rule added to segment [bold cyan]{key}[/bold cyan]: {clause_summary}"
+ )
+ console.print(f" [dim]id: {new_rule['id']}[/dim]")
+
+ _run(_run_add_rule)
+
+
+@segments_app.command("remove-rule")
+def segments_remove_rule(
+ key: str = typer.Argument(..., help="Segment key"),
+ rule_id: str = typer.Option(..., "--rule-id", "-r", help="Rule ID to remove"),
+) -> None:
+ """Remove a targeting rule from a segment by its ID.
+
+ \b
+ shield segments remove-rule beta_users --rule-id
+
+ Use 'shield segments get beta_users' to list rule IDs.
+ """
+
+ async def _run_remove_rule() -> None:
+ client = make_client()
+ seg = await client.get_segment(key)
+ rules = list(seg.get("rules") or [])
+ original_len = len(rules)
+ rules = [r for r in rules if r.get("id") != rule_id]
+ if len(rules) == original_len:
+ console.print(
+ f"[red]Error:[/red] no rule with id '{rule_id}' found on segment '{key}'."
+ )
+ raise typer.Exit(1)
+ seg["rules"] = rules
+ await client.update_segment(key, seg)
+ console.print(
+ f"[green]✓[/green] Rule [dim]{rule_id}[/dim] removed from segment "
+ f"[bold cyan]{key}[/bold cyan]."
+ )
+
+ _run(_run_remove_rule)
+
+
if __name__ == "__main__":
cli()
diff --git a/shield/core/backends/base.py b/shield/core/backends/base.py
index 9b3db76..e26e5a6 100644
--- a/shield/core/backends/base.py
+++ b/shield/core/backends/base.py
@@ -279,3 +279,72 @@ async def subscribe_rate_limit_policy(self) -> AsyncIterator[dict[str, Any]]:
f"{type(self).__name__} does not support rate limit policy pub/sub."
)
yield # make this a valid async generator
+
+ # ------------------------------------------------------------------
+ # Feature flag storage — concrete in-memory default implementations
+ #
+ # All backends get basic in-memory flag/segment storage for free.
+ # FileBackend and RedisBackend can override for persistence.
+ # Storage is lazily initialised on first use so existing backends
+ # that do not call super().__init__() are not affected.
+ # ------------------------------------------------------------------
+
+ def _flag_store(self) -> dict[str, Any]:
+ """Lazy per-instance dict for flag objects."""
+ if not hasattr(self, "_flag_store_dict"):
+ object.__setattr__(self, "_flag_store_dict", {})
+ return self._flag_store_dict # type: ignore[attr-defined, no-any-return]
+
+ def _segment_store(self) -> dict[str, Any]:
+ """Lazy per-instance dict for segment objects."""
+ if not hasattr(self, "_segment_store_dict"):
+ object.__setattr__(self, "_segment_store_dict", {})
+ return self._segment_store_dict # type: ignore[attr-defined, no-any-return]
+
+ async def load_all_flags(self) -> list[Any]:
+ """Return all stored feature flags.
+
+ Returns a list of ``FeatureFlag`` objects. The default
+ implementation uses an in-memory store. Override for persistent
+ backends.
+ """
+ return list(self._flag_store().values())
+
+ async def save_flag(self, flag: Any) -> None:
+ """Persist *flag* (a ``FeatureFlag`` instance) by its key.
+
+ Default implementation keeps flags in memory. Override for
+ persistent backends.
+ """
+ self._flag_store()[flag.key] = flag
+
+ async def delete_flag(self, flag_key: str) -> None:
+ """Remove the flag with *flag_key* from storage.
+
+ No-op if the flag does not exist.
+ """
+ self._flag_store().pop(flag_key, None)
+
+ async def load_all_segments(self) -> list[Any]:
+ """Return all stored segments.
+
+ Returns a list of ``Segment`` objects. The default
+ implementation uses an in-memory store. Override for persistent
+ backends.
+ """
+ return list(self._segment_store().values())
+
+ async def save_segment(self, segment: Any) -> None:
+ """Persist *segment* (a ``Segment`` instance) by its key.
+
+ Default implementation keeps segments in memory. Override for
+ persistent backends.
+ """
+ self._segment_store()[segment.key] = segment
+
+ async def delete_segment(self, segment_key: str) -> None:
+ """Remove the segment with *segment_key* from storage.
+
+ No-op if the segment does not exist.
+ """
+ self._segment_store().pop(segment_key, None)
diff --git a/shield/core/backends/memory.py b/shield/core/backends/memory.py
index 857a69a..f1c354c 100644
--- a/shield/core/backends/memory.py
+++ b/shield/core/backends/memory.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
+import contextlib
from collections import defaultdict, deque
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any
@@ -69,7 +70,7 @@ async def set_state(self, path: str, state: RouteState) -> None:
"""Persist *state* for *path* and notify any subscribers."""
self._states[path] = state
for queue in self._subscribers:
- await queue.put(state)
+ queue.put_nowait(state)
async def delete_state(self, path: str) -> None:
"""Remove state for *path*. No-op if not registered."""
@@ -120,7 +121,8 @@ async def subscribe(self) -> AsyncIterator[RouteState]:
state = await queue.get()
yield state
finally:
- self._subscribers.remove(queue)
+ with contextlib.suppress(ValueError):
+ self._subscribers.remove(queue)
async def write_rate_limit_hit(self, hit: RateLimitHit) -> None:
"""Append a rate limit hit record, evicting the oldest when the cap is reached."""
@@ -155,7 +157,7 @@ async def set_rate_limit_policy(
self._rl_policies[key] = policy_data
event: dict[str, Any] = {"action": "set", "key": key, "policy": policy_data}
for q in self._rl_policy_subscribers:
- await q.put(event)
+ q.put_nowait(event)
async def get_rate_limit_policies(self) -> list[dict[str, Any]]:
"""Return all persisted rate limit policies."""
@@ -167,7 +169,7 @@ async def delete_rate_limit_policy(self, path: str, method: str) -> None:
self._rl_policies.pop(key, None)
event: dict[str, Any] = {"action": "delete", "key": key}
for q in self._rl_policy_subscribers:
- await q.put(event)
+ q.put_nowait(event)
async def subscribe_rate_limit_policy(self) -> AsyncIterator[dict[str, Any]]:
"""Yield rate limit policy change events as they occur."""
@@ -177,4 +179,5 @@ async def subscribe_rate_limit_policy(self) -> AsyncIterator[dict[str, Any]]:
while True:
yield await queue.get()
finally:
- self._rl_policy_subscribers.remove(queue)
+ with contextlib.suppress(ValueError):
+ self._rl_policy_subscribers.remove(queue)
diff --git a/shield/core/backends/server.py b/shield/core/backends/server.py
index e231f57..2fa108c 100644
--- a/shield/core/backends/server.py
+++ b/shield/core/backends/server.py
@@ -96,6 +96,11 @@ def __init__(
self._rl_policy_cache: dict[str, dict[str, Any]] = {}
self._rl_policy_subscribers: list[asyncio.Queue[dict[str, Any]]] = []
+ # Local feature flag / segment cache (populated by SSE flag events).
+ self._flag_cache: dict[str, Any] = {} # key → FeatureFlag raw dict
+ self._segment_cache: dict[str, Any] = {} # key → Segment raw dict
+ self._flag_subscribers: list[asyncio.Queue[dict[str, Any]]] = []
+
self._client: httpx.AsyncClient | None = None
self._sse_task: asyncio.Task[None] | None = None
@@ -368,6 +373,64 @@ async def _listen_sse(self) -> None:
"ShieldServerBackend[%s]: RL policy deleted — %s", self._app_id, key
)
+ elif event_type == "flag_updated":
+ key = envelope.get("key", "")
+ flag_data = envelope.get("flag")
+ if key and flag_data is not None:
+ self._flag_cache[key] = flag_data
+ flag_event: dict[str, Any] = {
+ "type": "flag_updated",
+ "key": key,
+ "flag": flag_data,
+ }
+ for q in self._flag_subscribers:
+ q.put_nowait(flag_event)
+ logger.debug(
+ "ShieldServerBackend[%s]: flag cache updated — %s",
+ self._app_id,
+ key,
+ )
+
+ elif event_type == "flag_deleted":
+ key = envelope.get("key", "")
+ if key:
+ self._flag_cache.pop(key, None)
+ flag_del_event: dict[str, Any] = {"type": "flag_deleted", "key": key}
+ for q in self._flag_subscribers:
+ q.put_nowait(flag_del_event)
+ logger.debug(
+ "ShieldServerBackend[%s]: flag deleted — %s", self._app_id, key
+ )
+
+ elif event_type == "segment_updated":
+ key = envelope.get("key", "")
+ seg_data = envelope.get("segment")
+ if key and seg_data is not None:
+ self._segment_cache[key] = seg_data
+ seg_event: dict[str, Any] = {
+ "type": "segment_updated",
+ "key": key,
+ "segment": seg_data,
+ }
+ for q in self._flag_subscribers:
+ q.put_nowait(seg_event)
+ logger.debug(
+ "ShieldServerBackend[%s]: segment cache updated — %s",
+ self._app_id,
+ key,
+ )
+
+ elif event_type == "segment_deleted":
+ key = envelope.get("key", "")
+ if key:
+ self._segment_cache.pop(key, None)
+ seg_del_event: dict[str, Any] = {"type": "segment_deleted", "key": key}
+ for q in self._flag_subscribers:
+ q.put_nowait(seg_del_event)
+ logger.debug(
+ "ShieldServerBackend[%s]: segment deleted — %s", self._app_id, key
+ )
+
else:
# Legacy plain-RouteState payload (old server without typed envelopes).
try:
@@ -537,4 +600,36 @@ async def subscribe_rate_limit_policy(self) -> AsyncIterator[dict[str, Any]]:
while True:
yield await queue.get()
finally:
- self._rl_policy_subscribers.remove(queue)
+ with contextlib.suppress(ValueError):
+ self._rl_policy_subscribers.remove(queue)
+
+ async def subscribe_flag_changes(self) -> AsyncIterator[dict[str, Any]]:
+ """Yield feature flag / segment change events pushed via the SSE connection.
+
+ Each yielded dict has one of these shapes::
+
+ {"type": "flag_updated", "key": "my-flag", "flag": {...}}
+ {"type": "flag_deleted", "key": "my-flag"}
+ {"type": "segment_updated", "key": "my-seg", "segment": {...}}
+ {"type": "segment_deleted", "key": "my-seg"}
+ """
+ queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
+ self._flag_subscribers.append(queue)
+ try:
+ while True:
+ yield await queue.get()
+ finally:
+ with contextlib.suppress(ValueError):
+ self._flag_subscribers.remove(queue)
+
+ # ------------------------------------------------------------------
+ # Feature flag storage — returns locally cached data fetched via SSE
+ # ------------------------------------------------------------------
+
+ async def load_all_flags(self) -> list[Any]:
+ """Return all feature flags cached from the Shield Server."""
+ return list(self._flag_cache.values())
+
+ async def load_all_segments(self) -> list[Any]:
+ """Return all segments cached from the Shield Server."""
+ return list(self._segment_cache.values())
diff --git a/shield/core/engine.py b/shield/core/engine.py
index d1d9851..ead8155 100644
--- a/shield/core/engine.py
+++ b/shield/core/engine.py
@@ -284,6 +284,34 @@ def get_audit_log(self, path: str | None = None, limit: int = 100) -> list[Audit
"""Sync version of :meth:`ShieldEngine.get_audit_log`."""
return self._run(self._engine.get_audit_log(path=path, limit=limit))
+ # ------------------------------------------------------------------
+ # Feature flags
+ # ------------------------------------------------------------------
+
+ @property
+ def flag_client(self) -> Any:
+ """Return the synchronous flag client, or ``None`` if flags are not active.
+
+ Call ``engine.use_openfeature()`` first to activate the flag system.
+
+ Since OpenFeature evaluation is CPU-bound, this client does **not**
+ require a thread bridge — all methods are safe to call directly from
+ a ``def`` handler running in an anyio worker thread.
+
+ Example::
+
+ @router.get("/checkout")
+ def checkout(request: Request):
+ enabled = engine.sync.flag_client.get_boolean_value(
+ "new_checkout", False, {"targeting_key": request.state.user_id}
+ )
+ return checkout_v2() if enabled else checkout_v1()
+ """
+ fc = self._engine._flag_client
+ if fc is None:
+ return None
+ return fc.sync
+
class ShieldEngine:
"""Central orchestrator — all route lifecycle logic flows through here.
@@ -344,6 +372,10 @@ def __init__(
self._global_rate_limit_policy: Any = None # GlobalRateLimitPolicy | None
# Sync proxy — created once, reused on every engine.sync access.
self.sync: _SyncProxy = _SyncProxy(self)
+ # Feature flags — lazily set by use_openfeature().
+ self._flag_provider: Any = None # ShieldOpenFeatureProvider | None
+ self._flag_client: Any = None # ShieldFeatureClient | None
+ self._flag_scheduler: Any = None # FlagScheduler | None (set by use_openfeature)
# ------------------------------------------------------------------
# Async context manager — calls backend lifecycle hooks
@@ -407,6 +439,19 @@ async def start(self) -> None:
self._run_rl_policy_listener(),
name="shield-rl-policy-listener",
)
+ if self._flag_provider is not None:
+ # The OpenFeature SDK calls initialize() synchronously at
+ # set_provider() time. For async overrides the SDK silently
+ # discards the coroutine; engine.start() detects and awaits it.
+ # For sync initialize (including the base-class no-op) the SDK
+ # already ran it, so we skip the redundant call and go straight
+ # to warming the async backend cache.
+ if asyncio.iscoroutinefunction(type(self._flag_provider).initialize):
+ await self._flag_provider.initialize()
+ else:
+ await self._flag_provider._load_all()
+ if self._flag_scheduler is not None:
+ await self._flag_scheduler.start()
async def stop(self) -> None:
"""Cancel background listener tasks and wait for them to finish.
@@ -429,6 +474,157 @@ async def stop(self) -> None:
with contextlib.suppress(asyncio.CancelledError):
await self._rl_policy_listener_task
self._rl_policy_listener_task = None
+ if self._flag_scheduler is not None:
+ await self._flag_scheduler.stop()
+ if self._flag_provider is not None:
+ if asyncio.iscoroutinefunction(type(self._flag_provider).shutdown):
+ await self._flag_provider.shutdown()
+ else:
+ self._flag_provider.shutdown()
+
+ # ------------------------------------------------------------------
+ # Feature flags — OpenFeature wiring
+ # ------------------------------------------------------------------
+
+ def use_openfeature(
+ self,
+ provider: Any = None,
+ hooks: list[Any] | None = None,
+ domain: str = "shield",
+ ) -> Any:
+ """Activate the feature flag system backed by this engine's backend.
+
+ Parameters
+ ----------
+ provider:
+ An OpenFeature-compliant provider to use. Defaults to
+ ``ShieldOpenFeatureProvider(self.backend)`` — the built-in
+ provider backed by the same backend as the engine.
+ hooks:
+ Additional OpenFeature hooks to register globally. Default
+ hooks (``LoggingHook``) are always added.
+ domain:
+ The OpenFeature domain name for the client. Defaults to
+ ``"shield"``.
+
+ Returns
+ -------
+ ShieldFeatureClient
+ The feature client ready for flag evaluations.
+
+ Raises
+ ------
+ ImportError
+ When ``api-shield[flags]`` is not installed.
+ """
+ from shield.core.feature_flags._guard import _require_flags
+
+ _require_flags()
+
+ import openfeature.api as of_api
+ from openfeature.hook import Hook
+
+ from shield.core.feature_flags.client import ShieldFeatureClient
+ from shield.core.feature_flags.hooks import LoggingHook
+ from shield.core.feature_flags.provider import ShieldOpenFeatureProvider
+
+ if provider is None:
+ provider = ShieldOpenFeatureProvider(self.backend)
+
+ self._flag_provider = provider
+
+ # Register the provider under the given domain (OpenFeature >=0.8 API).
+ try:
+ of_api.set_provider(provider, domain=domain)
+ except TypeError:
+ # Older openfeature-sdk versions without domain support.
+ of_api.set_provider(provider)
+
+ from shield.core.feature_flags.hooks import MetricsHook
+
+ metrics_hook = MetricsHook()
+
+ # Build the default hook list and merge with any user-supplied hooks.
+ default_hooks: list[Hook] = [LoggingHook(), metrics_hook]
+ all_hooks = default_hooks + (hooks or [])
+ of_api.add_hooks(all_hooks)
+
+ # Create and cache the client.
+ self._flag_client = ShieldFeatureClient(domain=domain)
+
+ # Create the scheduler (start() is called later in engine.start()).
+ from shield.core.feature_flags.scheduler import FlagScheduler
+
+ self._flag_scheduler = FlagScheduler(self)
+
+ return self._flag_client
+
+ @property
+ def flag_client(self) -> Any:
+ """Return the active ``ShieldFeatureClient``, or ``None`` if not configured.
+
+ Call ``engine.use_openfeature()`` first to activate the flag system.
+ """
+ return self._flag_client
+
+ @property
+ def flag_scheduler(self) -> Any:
+ """Return the active ``FlagScheduler``, or ``None`` if not configured."""
+ return self._flag_scheduler
+
+ # ------------------------------------------------------------------
+ # Feature flag CRUD — single chokepoint for flag + segment operations
+ # ------------------------------------------------------------------
+
+ async def list_flags(self) -> list[Any]:
+ """Return all feature flags from the provider cache (or backend)."""
+ if self._flag_provider is not None:
+ return list(self._flag_provider._flags.values())
+ return await self.backend.load_all_flags()
+
+ async def get_flag(self, key: str) -> Any:
+ """Return a single ``FeatureFlag`` by *key*, or ``None`` if not found."""
+ if self._flag_provider is not None:
+ return self._flag_provider._flags.get(key)
+ flags = await self.backend.load_all_flags()
+ return next((f for f in flags if f.key == key), None)
+
+ async def save_flag(self, flag: Any) -> None:
+ """Persist *flag* to the backend and update the provider cache."""
+ await self.backend.save_flag(flag)
+ if self._flag_provider is not None:
+ self._flag_provider.upsert_flag(flag)
+
+ async def delete_flag(self, key: str) -> None:
+ """Delete a flag by *key* from the backend and provider cache."""
+ await self.backend.delete_flag(key)
+ if self._flag_provider is not None:
+ self._flag_provider.delete_flag(key)
+
+ async def list_segments(self) -> list[Any]:
+ """Return all segments from the provider cache (or backend)."""
+ if self._flag_provider is not None:
+ return list(self._flag_provider._segments.values())
+ return await self.backend.load_all_segments()
+
+ async def get_segment(self, key: str) -> Any:
+ """Return a single ``Segment`` by *key*, or ``None`` if not found."""
+ if self._flag_provider is not None:
+ return self._flag_provider._segments.get(key)
+ segments = await self.backend.load_all_segments()
+ return next((s for s in segments if s.key == key), None)
+
+ async def save_segment(self, segment: Any) -> None:
+ """Persist *segment* to the backend and update the provider cache."""
+ await self.backend.save_segment(segment)
+ if self._flag_provider is not None:
+ self._flag_provider.upsert_segment(segment)
+
+ async def delete_segment(self, key: str) -> None:
+ """Delete a segment by *key* from the backend and provider cache."""
+ await self.backend.delete_segment(key)
+ if self._flag_provider is not None:
+ self._flag_provider.delete_segment(key)
async def _run_global_config_listener(self) -> None:
"""Background coroutine: invalidate the global config cache on remote changes.
diff --git a/shield/core/feature_flags/__init__.py b/shield/core/feature_flags/__init__.py
new file mode 100644
index 0000000..62789c1
--- /dev/null
+++ b/shield/core/feature_flags/__init__.py
@@ -0,0 +1,153 @@
+"""shield.core.feature_flags — OpenFeature-compliant feature flag system.
+
+This package requires the [flags] optional extra::
+
+ pip install api-shield[flags]
+
+Importing from this package when the extra is not installed raises an
+``ImportError`` with clear installation instructions.
+
+All public symbols are re-exported under Shield-namespaced names.
+``openfeature`` never appears in user-facing imports.
+
+Usage
+-----
+::
+
+ from shield.core.feature_flags import (
+ EvaluationContext,
+ ShieldFeatureClient,
+ EvaluationReason,
+ ResolutionDetails,
+ )
+
+ ctx = EvaluationContext(key=user_id, attributes={"plan": "pro"})
+ value = await flag_client.get_boolean_value("new_checkout", False, ctx)
+
+Custom provider (implements OpenFeature's AbstractProvider)::
+
+ from shield.core.feature_flags import ShieldFlagProvider
+
+ class MyProvider(ShieldFlagProvider):
+ ...
+
+Custom hook (implements OpenFeature's Hook interface)::
+
+ from shield.core.feature_flags import ShieldHook
+"""
+
+from __future__ import annotations
+
+# ── Guard: raise early with a helpful message if openfeature not installed ──
+from shield.core.feature_flags._guard import _require_flags
+
+_require_flags()
+
+# ── OpenFeature ABC re-exports (Shield-namespaced) ──────────────────────────
+# These are the extension points for users who want custom providers/hooks.
+from openfeature.hook import Hook as ShieldHook
+from openfeature.provider import AbstractProvider as ShieldFlagProvider
+
+# ── Client and provider re-exports ──────────────────────────────────────────
+# Imported lazily here so the module graph stays clean.
+# client.py and provider.py each call _require_flags() themselves.
+from shield.core.feature_flags.client import ShieldFeatureClient as ShieldFeatureClient
+
+# ── Hook re-exports ─────────────────────────────────────────────────────────
+from shield.core.feature_flags.hooks import (
+ AuditHook as AuditHook,
+)
+from shield.core.feature_flags.hooks import (
+ LoggingHook as LoggingHook,
+)
+from shield.core.feature_flags.hooks import (
+ MetricsHook as MetricsHook,
+)
+from shield.core.feature_flags.hooks import (
+ OpenTelemetryHook as OpenTelemetryHook,
+)
+
+# ── Shield-native model re-exports ──────────────────────────────────────────
+from shield.core.feature_flags.models import (
+ EvaluationContext as EvaluationContext,
+)
+from shield.core.feature_flags.models import (
+ EvaluationReason as EvaluationReason,
+)
+from shield.core.feature_flags.models import (
+ FeatureFlag as FeatureFlag,
+)
+from shield.core.feature_flags.models import (
+ FlagStatus as FlagStatus,
+)
+from shield.core.feature_flags.models import (
+ FlagType as FlagType,
+)
+from shield.core.feature_flags.models import (
+ FlagVariation as FlagVariation,
+)
+from shield.core.feature_flags.models import (
+ Operator as Operator,
+)
+from shield.core.feature_flags.models import (
+ Prerequisite as Prerequisite,
+)
+from shield.core.feature_flags.models import (
+ ResolutionDetails as ResolutionDetails,
+)
+from shield.core.feature_flags.models import (
+ RolloutVariation as RolloutVariation,
+)
+from shield.core.feature_flags.models import (
+ RuleClause as RuleClause,
+)
+from shield.core.feature_flags.models import (
+ ScheduledChange as ScheduledChange,
+)
+from shield.core.feature_flags.models import (
+ ScheduledChangeAction as ScheduledChangeAction,
+)
+from shield.core.feature_flags.models import (
+ Segment as Segment,
+)
+from shield.core.feature_flags.models import (
+ SegmentRule as SegmentRule,
+)
+from shield.core.feature_flags.models import (
+ TargetingRule as TargetingRule,
+)
+from shield.core.feature_flags.provider import (
+ ShieldOpenFeatureProvider as ShieldOpenFeatureProvider,
+)
+
+__all__ = [
+ # Extension points
+ "ShieldFlagProvider",
+ "ShieldHook",
+ # Models
+ "EvaluationContext",
+ "EvaluationReason",
+ "FeatureFlag",
+ "FlagStatus",
+ "FlagType",
+ "FlagVariation",
+ "Operator",
+ "Prerequisite",
+ "ResolutionDetails",
+ "RolloutVariation",
+ "RuleClause",
+ "ScheduledChange",
+ "ScheduledChangeAction",
+ "Segment",
+ "SegmentRule",
+ "TargetingRule",
+ # Client
+ "ShieldFeatureClient",
+ # Provider
+ "ShieldOpenFeatureProvider",
+ # Hooks
+ "AuditHook",
+ "LoggingHook",
+ "MetricsHook",
+ "OpenTelemetryHook",
+]
diff --git a/shield/core/feature_flags/_context.py b/shield/core/feature_flags/_context.py
new file mode 100644
index 0000000..88b7671
--- /dev/null
+++ b/shield/core/feature_flags/_context.py
@@ -0,0 +1,65 @@
+"""Context conversion helpers between Shield and OpenFeature types.
+
+Converts ``shield.core.feature_flags.models.EvaluationContext`` →
+``openfeature.evaluation_context.EvaluationContext`` for provider dispatch,
+and back again for the native provider's evaluator calls.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from shield.core.feature_flags.models import EvaluationContext as ShieldContext
+
+
+def to_of_context(ctx: ShieldContext | None) -> object | None:
+ """Convert a Shield ``EvaluationContext`` to an OpenFeature one.
+
+ Returns ``None`` when *ctx* is ``None`` (OpenFeature accepts ``None``
+ to mean "use global context").
+
+ Also accepts plain ``dict`` for convenience in sync callers — the
+ ``targeting_key`` entry is mapped to the OpenFeature targeting key.
+ """
+ if ctx is None:
+ return None
+
+ from openfeature.evaluation_context import EvaluationContext as OFContext
+
+ if isinstance(ctx, dict):
+ d = dict(ctx)
+ targeting_key = d.pop("targeting_key", "anonymous")
+ return OFContext(targeting_key=targeting_key, attributes=d)
+
+ attrs = ctx.all_attributes()
+ # targeting_key is the OpenFeature equivalent of our ctx.key
+ targeting_key = attrs.pop("key", ctx.key)
+ return OFContext(targeting_key=targeting_key, attributes=attrs)
+
+
+def from_of_context(of_ctx: object | None) -> ShieldContext:
+ """Convert an OpenFeature ``EvaluationContext`` to a Shield one.
+
+ Used inside ``ShieldOpenFeatureProvider`` when the OpenFeature SDK
+ dispatches a resolution call so that ``FlagEvaluator`` receives the
+ right type.
+ """
+ from shield.core.feature_flags.models import EvaluationContext as ShieldContext
+
+ if of_ctx is None:
+ return ShieldContext(key="anonymous")
+
+ # OpenFeature EvaluationContext has targeting_key + attributes
+ targeting_key = getattr(of_ctx, "targeting_key", None) or "anonymous"
+ attributes: dict[str, Any] = getattr(of_ctx, "attributes", {}) or {}
+
+ return ShieldContext(
+ key=targeting_key,
+ kind=attributes.pop("kind", "user"),
+ email=attributes.pop("email", None),
+ ip=attributes.pop("ip", None),
+ country=attributes.pop("country", None),
+ app_version=attributes.pop("app_version", None),
+ attributes=attributes,
+ )
diff --git a/shield/core/feature_flags/_guard.py b/shield/core/feature_flags/_guard.py
new file mode 100644
index 0000000..616c157
--- /dev/null
+++ b/shield/core/feature_flags/_guard.py
@@ -0,0 +1,26 @@
+"""Import guard for the feature flags optional dependency.
+
+Call ``_require_flags()`` at the top of any module that needs
+``openfeature`` before attempting to import it. This produces a clear,
+actionable error message instead of a bare ``ModuleNotFoundError``.
+
+``shield/core/feature_flags/models.py`` and ``evaluator.py`` are pure
+Pydantic/stdlib and do **not** call this guard — they are importable
+regardless of whether the [flags] extra is installed. Only the public
+``shield.core.feature_flags`` namespace (``__init__.py``) and the
+provider/client modules call this guard.
+"""
+
+from __future__ import annotations
+
+
+def _require_flags() -> None:
+ """Raise ``ImportError`` with install instructions if openfeature is missing."""
+ try:
+ import openfeature # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Feature flags require the [flags] extra.\n"
+ "Install with: pip install api-shield[flags]\n"
+ "Or: uv pip install 'api-shield[flags]'"
+ ) from None
diff --git a/shield/core/feature_flags/client.py b/shield/core/feature_flags/client.py
new file mode 100644
index 0000000..9918578
--- /dev/null
+++ b/shield/core/feature_flags/client.py
@@ -0,0 +1,171 @@
+"""ShieldFeatureClient — OpenFeature-backed flag evaluation API.
+
+Phase 2 implementation. Stub present so the package imports cleanly.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from shield.core.feature_flags._guard import _require_flags
+
+_require_flags()
+
+
+class _SyncShieldFeatureClient:
+ """Synchronous façade over :class:`ShieldFeatureClient`.
+
+ Access via ``engine.sync.flag_client`` from sync route handlers.
+ FastAPI runs ``def`` handlers in anyio worker threads, which is exactly
+ the context this class is designed for.
+
+ Because OpenFeature evaluation is CPU-bound (pure Python, no I/O), all
+ methods call the underlying OpenFeature client directly — no thread
+ bridge or event-loop interaction needed.
+
+ Examples
+ --------
+ ::
+
+ @router.get("/checkout")
+ def checkout(request: Request):
+ enabled = engine.sync.flag_client.get_boolean_value(
+ "new_checkout", False, {"targeting_key": request.state.user_id}
+ )
+ if enabled:
+ return checkout_v2()
+ return checkout_v1()
+ """
+
+ __slots__ = ("_of_client",)
+
+ def __init__(self, of_client: object) -> None:
+ # ``of_client`` is the raw openfeature Client, not ShieldFeatureClient.
+ self._of_client = of_client
+
+ def get_boolean_value(
+ self,
+ flag_key: str,
+ default: bool,
+ ctx: object | None = None,
+ ) -> bool:
+ """Evaluate a boolean flag synchronously."""
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._of_client.get_boolean_value(flag_key, default, to_of_context(ctx)) # type: ignore[attr-defined, no-any-return, arg-type]
+
+ def get_string_value(
+ self,
+ flag_key: str,
+ default: str,
+ ctx: object | None = None,
+ ) -> str:
+ """Evaluate a string flag synchronously."""
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._of_client.get_string_value(flag_key, default, to_of_context(ctx)) # type: ignore[attr-defined, no-any-return, arg-type]
+
+ def get_integer_value(
+ self,
+ flag_key: str,
+ default: int,
+ ctx: object | None = None,
+ ) -> int:
+ """Evaluate an integer flag synchronously."""
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._of_client.get_integer_value(flag_key, default, to_of_context(ctx)) # type: ignore[attr-defined, no-any-return, arg-type]
+
+ def get_float_value(
+ self,
+ flag_key: str,
+ default: float,
+ ctx: object | None = None,
+ ) -> float:
+ """Evaluate a float flag synchronously."""
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._of_client.get_float_value(flag_key, default, to_of_context(ctx)) # type: ignore[attr-defined, no-any-return, arg-type]
+
+ def get_object_value(
+ self,
+ flag_key: str,
+ default: dict, # type: ignore[type-arg]
+ ctx: object | None = None,
+ ) -> dict: # type: ignore[type-arg]
+ """Evaluate a JSON/object flag synchronously."""
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._of_client.get_object_value(flag_key, default, to_of_context(ctx)) # type: ignore[attr-defined, no-any-return, arg-type]
+
+
+class ShieldFeatureClient:
+ """Thin wrapper around the OpenFeature client.
+
+ Instantiated via ``engine.use_openfeature()``.
+ Do not construct directly.
+ """
+
+ def __init__(self, domain: str = "shield") -> None:
+ from openfeature import api
+
+ self._client = api.get_client(domain)
+ self._domain = domain
+
+ async def get_boolean_value(
+ self,
+ flag_key: str,
+ default: bool,
+ ctx: object | None = None,
+ ) -> bool:
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._client.get_boolean_value(flag_key, default, to_of_context(ctx)) # type: ignore[arg-type]
+
+ async def get_string_value(
+ self,
+ flag_key: str,
+ default: str,
+ ctx: object | None = None,
+ ) -> str:
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._client.get_string_value(flag_key, default, to_of_context(ctx)) # type: ignore[arg-type]
+
+ async def get_integer_value(
+ self,
+ flag_key: str,
+ default: int,
+ ctx: object | None = None,
+ ) -> int:
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._client.get_integer_value(flag_key, default, to_of_context(ctx)) # type: ignore[arg-type]
+
+ async def get_float_value(
+ self,
+ flag_key: str,
+ default: float,
+ ctx: object | None = None,
+ ) -> float:
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._client.get_float_value(flag_key, default, to_of_context(ctx)) # type: ignore[arg-type]
+
+ async def get_object_value(
+ self,
+ flag_key: str,
+ default: dict[str, Any],
+ ctx: object | None = None,
+ ) -> dict[str, Any]:
+ from shield.core.feature_flags._context import to_of_context
+
+ return self._client.get_object_value(flag_key, default, to_of_context(ctx)) # type: ignore[arg-type, return-value]
+
+ @property
+ def sync(self) -> _SyncShieldFeatureClient:
+ """Return a synchronous façade for use in ``def`` (non-async) handlers.
+
+ Prefer ``engine.sync.flag_client`` over accessing this directly.
+ """
+ return _SyncShieldFeatureClient(self._client)
diff --git a/shield/core/feature_flags/evaluator.py b/shield/core/feature_flags/evaluator.py
new file mode 100644
index 0000000..f7d6456
--- /dev/null
+++ b/shield/core/feature_flags/evaluator.py
@@ -0,0 +1,443 @@
+"""Pure feature flag evaluation engine.
+
+No I/O, no async, no openfeature dependency. Fully unit-testable in
+isolation by constructing ``FeatureFlag`` and ``EvaluationContext``
+objects directly.
+
+Evaluation order
+----------------
+1. Flag disabled (``enabled=False``) → ``off_variation``
+2. Prerequisites — recursive, short-circuits on first failure
+3. Individual targets — ``flag.targets[variation]`` contains ``ctx.key``
+4. Rules — top-to-bottom, first matching rule wins
+5. Fallthrough — fixed variation or percentage rollout bucket
+
+Clause semantics
+----------------
+- All clauses within a rule are AND-ed (all must match).
+- Multiple values within one clause are OR-ed (any value must match).
+- ``negate=True`` inverts the final result of the clause.
+
+Rollout bucketing
+-----------------
+SHA-1 hash of ``"{flag_key}:{ctx.kind}:{ctx.key}"`` modulo 100_000.
+Deterministic and stable — the same context always lands in the same
+bucket. Weights in ``RolloutVariation`` lists should sum to 100_000.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+import re
+from typing import Any
+
+from shield.core.feature_flags.models import (
+ EvaluationContext,
+ EvaluationReason,
+ FeatureFlag,
+ Operator,
+ ResolutionDetails,
+ RolloutVariation,
+ RuleClause,
+ Segment,
+ TargetingRule,
+)
+
+logger = logging.getLogger(__name__)
+
+# Maximum prerequisite recursion depth to prevent accidental infinite loops.
+_MAX_PREREQ_DEPTH = 10
+
+
+class FlagEvaluator:
+ """Evaluate feature flags against an evaluation context.
+
+ Parameters
+ ----------
+ segments:
+ Preloaded mapping of segment key → ``Segment``. Pass an empty
+ dict if no segments are defined. Updated in-place by the
+ provider on hot-reload.
+
+ Examples
+ --------
+ ::
+
+ evaluator = FlagEvaluator(segments={"beta": beta_segment})
+ result = evaluator.evaluate(flag, ctx, all_flags)
+ print(result.value, result.reason)
+ """
+
+ def __init__(self, segments: dict[str, Segment]) -> None:
+ self._segments = segments
+
+ # ── Public interface ────────────────────────────────────────────────────
+
+ def evaluate(
+ self,
+ flag: FeatureFlag,
+ ctx: EvaluationContext,
+ all_flags: dict[str, FeatureFlag],
+ *,
+ _depth: int = 0,
+ ) -> ResolutionDetails:
+ """Evaluate *flag* for *ctx* and return a ``ResolutionDetails``.
+
+ Parameters
+ ----------
+ flag:
+ The flag to evaluate.
+ ctx:
+ Per-request evaluation context.
+ all_flags:
+ Full flag map — required for prerequisite resolution.
+ _depth:
+ Internal recursion counter. Do not pass from call sites.
+ """
+ if _depth > _MAX_PREREQ_DEPTH:
+ logger.error(
+ "api-shield flags: prerequisite depth limit reached for flag '%s'. "
+ "Serving off_variation to prevent infinite recursion.",
+ flag.key,
+ )
+ return self._off(
+ flag,
+ reason=EvaluationReason.ERROR,
+ error_message="Prerequisite depth limit exceeded",
+ )
+
+ # Step 1: global kill-switch
+ if not flag.enabled:
+ return self._off(flag, reason=EvaluationReason.OFF)
+
+ # Step 2: prerequisites
+ for prereq in flag.prerequisites:
+ prereq_flag = all_flags.get(prereq.flag_key)
+ if prereq_flag is None:
+ logger.warning(
+ "api-shield flags: prerequisite flag '%s' not found "
+ "for flag '%s'. Serving off_variation.",
+ prereq.flag_key,
+ flag.key,
+ )
+ return self._off(
+ flag,
+ reason=EvaluationReason.PREREQUISITE_FAIL,
+ prerequisite_key=prereq.flag_key,
+ )
+ prereq_result = self.evaluate(prereq_flag, ctx, all_flags, _depth=_depth + 1)
+ if prereq_result.variation != prereq.variation:
+ return self._off(
+ flag,
+ reason=EvaluationReason.PREREQUISITE_FAIL,
+ prerequisite_key=prereq.flag_key,
+ )
+
+ # Step 3: individual targets
+ for variation_name, keys in flag.targets.items():
+ if ctx.key in keys:
+ return ResolutionDetails(
+ value=flag.get_variation_value(variation_name),
+ variation=variation_name,
+ reason=EvaluationReason.TARGET_MATCH,
+ )
+
+ # Step 4: targeting rules (top-to-bottom, first match wins)
+ for rule in flag.rules:
+ if self._rule_matches(rule, ctx):
+ variation_name = self._resolve_rule_variation(rule, ctx, flag)
+ return ResolutionDetails(
+ value=flag.get_variation_value(variation_name),
+ variation=variation_name,
+ reason=EvaluationReason.RULE_MATCH,
+ rule_id=rule.id,
+ )
+
+ # Step 5: fallthrough (default rule)
+ variation_name = self._resolve_fallthrough(flag, ctx)
+ return ResolutionDetails(
+ value=flag.get_variation_value(variation_name),
+ variation=variation_name,
+ reason=EvaluationReason.FALLTHROUGH,
+ )
+
+ # ── Rule and clause matching ────────────────────────────────────────────
+
+ def _rule_matches(self, rule: TargetingRule, ctx: EvaluationContext) -> bool:
+ """Return ``True`` if ALL clauses in *rule* match *ctx* (AND logic)."""
+ return all(self._clause_matches(clause, ctx) for clause in rule.clauses)
+
+ def _clause_matches(self, clause: RuleClause, ctx: EvaluationContext) -> bool:
+ """Evaluate a single clause against the context.
+
+ Applies the operator, then inverts the result if ``negate=True``.
+ Returns ``False`` when the attribute is missing and the operator
+ requires a value (safe default — missing attribute → no match).
+ """
+ attrs = ctx.all_attributes()
+ actual = attrs.get(clause.attribute)
+ result = self._apply_operator(clause.operator, actual, clause.values)
+ return not result if clause.negate else result
+
+ def _apply_operator(self, op: Operator, actual: Any, values: list[Any]) -> bool:
+ """Apply *op* comparing *actual* against *values*.
+
+ Multiple values use OR logic — returns ``True`` if any value matches.
+ Missing ``actual`` (``None``) returns ``False`` for all operators
+ except ``IS_NOT`` and ``NOT_IN``.
+ """
+ # Segment operators delegate to _in_segment
+ if op == Operator.IN_SEGMENT:
+ return any(self._in_segment(actual, seg_key, _ctx=None) for seg_key in values)
+ if op == Operator.NOT_IN_SEGMENT:
+ return all(not self._in_segment(actual, seg_key, _ctx=None) for seg_key in values)
+
+ if actual is None:
+ # Only IS_NOT and NOT_IN make sense with None
+ if op == Operator.IS_NOT:
+ return all(v is not None for v in values)
+ if op == Operator.NOT_IN:
+ return None not in values
+ return False
+
+ match op:
+ # ── Equality ────────────────────────────────────────────────
+ case Operator.IS:
+ return any(actual == v for v in values)
+ case Operator.IS_NOT:
+ return all(actual != v for v in values)
+ # ── String ──────────────────────────────────────────────────
+ case Operator.CONTAINS:
+ s = str(actual)
+ return any(str(v) in s for v in values)
+ case Operator.NOT_CONTAINS:
+ s = str(actual)
+ return all(str(v) not in s for v in values)
+ case Operator.STARTS_WITH:
+ s = str(actual)
+ return any(s.startswith(str(v)) for v in values)
+ case Operator.ENDS_WITH:
+ s = str(actual)
+ return any(s.endswith(str(v)) for v in values)
+ case Operator.MATCHES:
+ s = str(actual)
+ return any(_safe_regex(str(v), s) for v in values)
+ case Operator.NOT_MATCHES:
+ s = str(actual)
+ return all(not _safe_regex(str(v), s) for v in values)
+ # ── Numeric ─────────────────────────────────────────────────
+ case Operator.GT:
+ return _numeric_op(actual, values[0], lambda a, b: a > b)
+ case Operator.GTE:
+ return _numeric_op(actual, values[0], lambda a, b: a >= b)
+ case Operator.LT:
+ return _numeric_op(actual, values[0], lambda a, b: a < b)
+ case Operator.LTE:
+ return _numeric_op(actual, values[0], lambda a, b: a <= b)
+ # ── Date (ISO-8601 string lexicographic comparison) ──────────
+ case Operator.BEFORE:
+ return str(actual) < str(values[0])
+ case Operator.AFTER:
+ return str(actual) > str(values[0])
+ # ── Collection ──────────────────────────────────────────────
+ case Operator.IN:
+ return actual in values
+ case Operator.NOT_IN:
+ return actual not in values
+ # ── Semantic version ────────────────────────────────────────
+ case Operator.SEMVER_EQ:
+ return _semver_op(actual, values[0], "eq")
+ case Operator.SEMVER_LT:
+ return _semver_op(actual, values[0], "lt")
+ case Operator.SEMVER_GT:
+ return _semver_op(actual, values[0], "gt")
+ case _:
+ logger.warning("api-shield flags: unknown operator '%s'", op)
+ return False
+
+ # ── Segment evaluation ──────────────────────────────────────────────────
+
+ def _in_segment(
+ self,
+ context_key: str | None,
+ segment_key: str,
+ *,
+ _ctx: EvaluationContext | None,
+ ) -> bool:
+ """Return ``True`` if *context_key* is a member of *segment_key*.
+
+ Evaluation order:
+ 1. Key in ``excluded`` → False
+ 2. Key in ``included`` → True
+ 3. Any segment rule matches → True
+ 4. Otherwise → False
+ """
+ if context_key is None:
+ return False
+
+ seg = self._segments.get(segment_key)
+ if seg is None:
+ logger.warning(
+ "api-shield flags: segment '%s' not found — treating as empty.",
+ segment_key,
+ )
+ return False
+
+ if context_key in seg.excluded:
+ return False
+ if context_key in seg.included:
+ return True
+
+ if _ctx is None:
+ # Segment rules need the full context — called from a clause
+ # that only passed the context key, not the full EvaluationContext.
+ # Without the full context we can't evaluate rules.
+ return False
+
+ for rule in seg.rules:
+ if all(self._clause_matches(clause, _ctx) for clause in rule.clauses):
+ return True
+
+ return False
+
+ def _clause_matches_with_ctx(self, clause: RuleClause, ctx: EvaluationContext) -> bool:
+ """Clause match variant that passes *ctx* into segment evaluation."""
+ if clause.operator in (Operator.IN_SEGMENT, Operator.NOT_IN_SEGMENT):
+ actual = ctx.key
+ if clause.operator == Operator.IN_SEGMENT:
+ result = any(
+ self._in_segment(actual, seg_key, _ctx=ctx) for seg_key in clause.values
+ )
+ else:
+ result = all(
+ not self._in_segment(actual, seg_key, _ctx=ctx) for seg_key in clause.values
+ )
+ return not result if clause.negate else result
+ return self._clause_matches(clause, ctx)
+
+ def _rule_matches(self, rule: TargetingRule, ctx: EvaluationContext) -> bool: # type: ignore[no-redef]
+ """Return ``True`` if ALL clauses in *rule* match *ctx*.
+
+ Uses ``_clause_matches_with_ctx`` so that segment operators receive
+ the full context for rule evaluation.
+ """
+ return all(self._clause_matches_with_ctx(clause, ctx) for clause in rule.clauses)
+
+ # ── Rollout and variation resolution ───────────────────────────────────
+
+ def _resolve_rule_variation(
+ self, rule: TargetingRule, ctx: EvaluationContext, flag: FeatureFlag
+ ) -> str:
+ """Return the variation name to serve for a matched rule."""
+ if rule.variation is not None:
+ return rule.variation
+ if rule.rollout:
+ return self._bucket_rollout(rule.rollout, ctx, flag.key)
+ # Malformed rule — fall through to flag default
+ logger.warning(
+ "api-shield flags: rule '%s' on flag '%s' has neither variation "
+ "nor rollout — falling through to default.",
+ rule.id,
+ flag.key,
+ )
+ return self._resolve_fallthrough(flag, ctx)
+
+ def _resolve_fallthrough(self, flag: FeatureFlag, ctx: EvaluationContext) -> str:
+ """Return the variation name for the fallthrough (default) rule."""
+ if isinstance(flag.fallthrough, str):
+ return flag.fallthrough
+ return self._bucket_rollout(flag.fallthrough, ctx, flag.key)
+
+ @staticmethod
+ def _bucket_rollout(
+ rollout: list[RolloutVariation],
+ ctx: EvaluationContext,
+ flag_key: str,
+ ) -> str:
+ """Deterministic bucket assignment for percentage rollouts.
+
+ Uses SHA-1 of ``"{flag_key}:{ctx.kind}:{ctx.key}"`` for stable,
+ consistent assignment. Bucket range is 0–99_999 (100_000 total)
+ matching the weight precision of ``RolloutVariation.weight``.
+
+ Returns the last variation if weights don't sum to 100_000 (safe
+ fallback — never raises).
+ """
+ seed = f"{flag_key}:{ctx.kind}:{ctx.key}"
+ bucket = int(hashlib.sha1(seed.encode()).hexdigest(), 16) % 100_000
+ cumulative = 0
+ for rv in rollout:
+ cumulative += rv.weight
+ if bucket < cumulative:
+ return rv.variation
+ return rollout[-1].variation
+
+ # ── Helpers ─────────────────────────────────────────────────────────────
+
+ @staticmethod
+ def _off(
+ flag: FeatureFlag,
+ *,
+ reason: EvaluationReason,
+ prerequisite_key: str | None = None,
+ error_message: str | None = None,
+ ) -> ResolutionDetails:
+ return ResolutionDetails(
+ value=flag.get_variation_value(flag.off_variation),
+ variation=flag.off_variation,
+ reason=reason,
+ prerequisite_key=prerequisite_key,
+ error_message=error_message,
+ )
+
+
+# ── Module-level helpers ──────────────────────────────────────────────────────
+
+
+def _safe_regex(pattern: str, string: str) -> bool:
+ """Apply regex *pattern* to *string*, returning ``False`` on error."""
+ try:
+ return bool(re.search(pattern, string))
+ except re.error as exc:
+ logger.warning("api-shield flags: invalid regex '%s': %s", pattern, exc)
+ return False
+
+
+def _numeric_op(actual: Any, threshold: Any, comparator: Any) -> bool:
+ """Apply a numeric comparison, returning ``False`` on type errors."""
+ try:
+ return comparator(float(actual), float(threshold)) # type: ignore[no-any-return]
+ except (TypeError, ValueError):
+ return False
+
+
+def _semver_op(actual: Any, threshold: Any, op: str) -> bool:
+ """Apply a semantic version comparison using ``packaging.version``.
+
+ Falls back to ``False`` if ``packaging`` is not installed or the
+ version strings are malformed.
+ """
+ try:
+ from packaging.version import Version
+
+ a = Version(str(actual))
+ b = Version(str(threshold))
+ if op == "eq":
+ return a == b
+ if op == "lt":
+ return a < b
+ if op == "gt":
+ return a > b
+ except ImportError:
+ logger.warning(
+ "api-shield flags: semver operators require 'packaging'. "
+ "Install with: pip install api-shield[flags]"
+ )
+ except Exception: # noqa: BLE001
+ logger.warning(
+ "api-shield flags: semver comparison failed for values '%s' and '%s'.",
+ actual,
+ threshold,
+ )
+ return False
diff --git a/shield/core/feature_flags/hooks.py b/shield/core/feature_flags/hooks.py
new file mode 100644
index 0000000..f1e6019
--- /dev/null
+++ b/shield/core/feature_flags/hooks.py
@@ -0,0 +1,168 @@
+"""Built-in OpenFeature hooks for api-shield.
+
+All hooks implement OpenFeature's ``Hook`` interface and are registered
+via ``engine.use_openfeature(hooks=[...])``.
+
+Built-in hooks registered by default
+-------------------------------------
+``LoggingHook`` — logs every evaluation at DEBUG level.
+``AuditHook`` — records non-trivial evaluations in ShieldEngine's audit log.
+``MetricsHook`` — increments per-variation counters for dashboard stats.
+
+Optional hooks (user-registered)
+---------------------------------
+``OpenTelemetryHook`` — sets ``feature_flag.*`` span attributes on the
+current OpenTelemetry span. Requires ``opentelemetry-api`` to be installed.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from shield.core.feature_flags._guard import _require_flags
+
+_require_flags()
+
+from openfeature.flag_evaluation import FlagEvaluationDetails, FlagValueType
+from openfeature.hook import Hook, HookContext, HookHints
+
+logger = logging.getLogger(__name__)
+
+
+class LoggingHook(Hook):
+ """Log every flag evaluation at DEBUG level.
+
+ Automatically registered by ``engine.use_openfeature()``.
+ """
+
+ def after(
+ self,
+ hook_context: HookContext,
+ details: FlagEvaluationDetails[FlagValueType],
+ hints: HookHints,
+ ) -> None:
+ logger.debug(
+ "api-shield flag eval: key=%s variant=%s reason=%s",
+ hook_context.flag_key,
+ details.variant,
+ details.reason,
+ )
+
+ def error(
+ self,
+ hook_context: HookContext,
+ exception: Exception,
+ hints: HookHints,
+ ) -> None:
+ logger.error(
+ "api-shield flag error: key=%s error=%s",
+ hook_context.flag_key,
+ exception,
+ )
+
+
+class AuditHook(Hook):
+ """Record flag evaluations in ShieldEngine's audit log.
+
+ Only records evaluations with non-trivial reasons (RULE_MATCH,
+ TARGET_MATCH, PREREQUISITE_FAIL, ERROR) to avoid polluting the audit
+ log with FALLTHROUGH and DEFAULT entries.
+
+ Automatically registered by ``engine.use_openfeature()``.
+
+ Parameters
+ ----------
+ engine:
+ The ``ShieldEngine`` instance to write audit entries to.
+ """
+
+ # Reasons worth recording
+ _RECORD_REASONS = frozenset(["TARGETING_MATCH", "DISABLED", "ERROR"])
+
+ def __init__(self, engine: Any) -> None:
+ self._engine = engine
+
+ def after(
+ self,
+ hook_context: HookContext,
+ details: FlagEvaluationDetails[FlagValueType],
+ hints: HookHints,
+ ) -> None:
+ if details.reason not in self._RECORD_REASONS:
+ return
+ # Fire-and-forget — audit writes are best-effort
+ import asyncio
+ import contextlib
+
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ with contextlib.suppress(Exception):
+ loop.create_task(
+ self._engine.record_flag_evaluation(hook_context.flag_key, details)
+ )
+
+
+class MetricsHook(Hook):
+ """Increment per-variation evaluation counters.
+
+ Parameters
+ ----------
+ collector:
+ ``FlagMetricsCollector`` instance that stores the counters.
+ """
+
+ def __init__(self, collector: Any = None) -> None:
+ self._collector = collector
+
+ def after(
+ self,
+ hook_context: HookContext,
+ details: FlagEvaluationDetails[FlagValueType],
+ hints: HookHints,
+ ) -> None:
+ import asyncio
+ import contextlib
+
+ ctx = hook_context.evaluation_context
+ targeting_key = getattr(ctx, "targeting_key", "anonymous") if ctx else "anonymous"
+
+ record = {
+ "variation": details.variant or "unknown",
+ "reason": details.reason or "UNKNOWN",
+ "context_key": targeting_key,
+ }
+
+ if self._collector is not None:
+ loop = asyncio.get_event_loop()
+ if loop.is_running():
+ with contextlib.suppress(Exception):
+ loop.create_task(self._collector.record(hook_context.flag_key, record))
+
+
+class OpenTelemetryHook(Hook):
+ """Set ``feature_flag.*`` span attributes on the current OTel span.
+
+ No-ops gracefully when ``opentelemetry-api`` is not installed.
+ Optional — register via ``engine.use_openfeature(hooks=[OpenTelemetryHook()])``.
+ """
+
+ def after(
+ self,
+ hook_context: HookContext,
+ details: FlagEvaluationDetails[FlagValueType],
+ hints: HookHints,
+ ) -> None:
+ try:
+ from opentelemetry import trace # type: ignore[import-not-found]
+
+ span = trace.get_current_span()
+ if span.is_recording():
+ key = hook_context.flag_key
+ span.set_attribute(f"feature_flag.{key}.value", str(details.value))
+ if details.variant:
+ span.set_attribute(f"feature_flag.{key}.variant", details.variant)
+ if details.reason:
+ span.set_attribute(f"feature_flag.{key}.reason", details.reason)
+ except ImportError:
+ pass # opentelemetry-api not installed — silently skip
diff --git a/shield/core/feature_flags/models.py b/shield/core/feature_flags/models.py
new file mode 100644
index 0000000..91f3a0b
--- /dev/null
+++ b/shield/core/feature_flags/models.py
@@ -0,0 +1,590 @@
+"""Feature flag data models for api-shield.
+
+All models are pure Pydantic v2 with no dependency on ``openfeature``.
+This module is importable even without the [flags] extra installed.
+
+Design notes
+------------
+``EvaluationContext.all_attributes()`` merges named convenience fields
+(email, ip, country, app_version) with the free-form ``attributes`` dict
+so that rule clauses can reference any of them by name without callers
+having to manually populate ``attributes`` for common fields.
+
+``RolloutVariation.weight`` is out of 100_000 (not 100) to allow
+fine-grained rollouts like 0.1%, 33.33%, etc. — same precision as
+LaunchDarkly. Weights in a rollout list should sum to 100_000.
+
+``FeatureFlag.targets`` maps variation name → list of context keys for
+individual targeting. Evaluated before rules (highest priority after
+prerequisites).
+
+``FeatureFlag.fallthrough`` accepts either a plain variation name
+(``str``) for a fixed default, or a list of ``RolloutVariation`` for a
+percentage-based default rule.
+"""
+
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+from enum import StrEnum
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+# ── Flag type ────────────────────────────────────────────────────────────────
+
+
+class FlagType(StrEnum):
+ """Value type of a feature flag's variations."""
+
+ BOOLEAN = "boolean"
+ STRING = "string"
+ INTEGER = "integer"
+ FLOAT = "float"
+ JSON = "json"
+
+
+# ── Variations ───────────────────────────────────────────────────────────────
+
+
+class FlagVariation(BaseModel):
+ """A single named variation of a feature flag.
+
+ Parameters
+ ----------
+ name:
+ Identifier used in rules, targets, fallthrough, and prerequisites.
+ E.g. ``"on"``, ``"off"``, ``"control"``, ``"variant_a"``.
+ value:
+ The actual value returned when this variation is served.
+ Must match the flag's ``type``.
+ description:
+ Optional human-readable note shown in the dashboard.
+ """
+
+ name: str
+ value: bool | str | int | float | dict[str, Any] | list[Any]
+ description: str = ""
+
+
+class RolloutVariation(BaseModel):
+ """One bucket in a percentage rollout.
+
+ Parameters
+ ----------
+ variation:
+ References ``FlagVariation.name``.
+ weight:
+ Share of traffic (out of 100_000). All weights in a rollout
+ list should sum to 100_000. E.g. 25% = 25_000.
+ """
+
+ variation: str
+ weight: int = Field(ge=0, le=100_000)
+
+
+# ── Targeting operators ──────────────────────────────────────────────────────
+
+
+class Operator(StrEnum):
+ """All supported targeting rule operators.
+
+ String operators
+ ----------------
+ ``IS`` / ``IS_NOT`` — exact string equality.
+ ``CONTAINS`` / ``NOT_CONTAINS`` — substring match.
+ ``STARTS_WITH`` / ``ENDS_WITH`` — prefix / suffix match.
+ ``MATCHES`` / ``NOT_MATCHES`` — regex match (Python ``re`` module).
+
+ Numeric operators
+ -----------------
+ ``GT`` / ``GTE`` / ``LT`` / ``LTE`` — numeric comparisons.
+
+ Date operators
+ --------------
+ ``BEFORE`` / ``AFTER`` — ISO-8601 string comparisons (lexicographic).
+
+ Collection operators
+ --------------------
+ ``IN`` / ``NOT_IN`` — membership in a list of values.
+
+ Segment operators
+ -----------------
+ ``IN_SEGMENT`` / ``NOT_IN_SEGMENT`` — context is/isn't in a named segment.
+
+ Semantic version operators
+ --------------------------
+ ``SEMVER_EQ`` / ``SEMVER_LT`` / ``SEMVER_GT`` — PEP 440 / semver
+ comparison using ``packaging.version.Version``.
+ Requires ``packaging`` (installed with the [flags] extra).
+ """
+
+ # Equality
+ IS = "is"
+ IS_NOT = "is_not"
+ # String
+ CONTAINS = "contains"
+ NOT_CONTAINS = "not_contains"
+ STARTS_WITH = "starts_with"
+ ENDS_WITH = "ends_with"
+ MATCHES = "matches"
+ NOT_MATCHES = "not_matches"
+ # Numeric
+ GT = "gt"
+ GTE = "gte"
+ LT = "lt"
+ LTE = "lte"
+ # Date
+ BEFORE = "before"
+ AFTER = "after"
+ # Collection
+ IN = "in"
+ NOT_IN = "not_in"
+ # Segment
+ IN_SEGMENT = "in_segment"
+ NOT_IN_SEGMENT = "not_in_segment"
+ # Semantic version
+ SEMVER_EQ = "semver_eq"
+ SEMVER_LT = "semver_lt"
+ SEMVER_GT = "semver_gt"
+
+
+# ── Rules ────────────────────────────────────────────────────────────────────
+
+
+class RuleClause(BaseModel):
+ """A single condition in a targeting rule.
+
+ All clauses within a rule are AND-ed together.
+ Multiple values within one clause are OR-ed (any value must match).
+
+ Parameters
+ ----------
+ attribute:
+ Context attribute to inspect. E.g. ``"role"``, ``"plan"``,
+ ``"email"``, ``"country"``, ``"app_version"``.
+ operator:
+ Comparison operator to apply.
+ values:
+ One or more values to compare against. Multiple values use
+ OR logic — the clause passes if *any* value matches.
+ negate:
+ When ``True``, the result of the operator check is inverted.
+ """
+
+ attribute: str
+ operator: Operator
+ values: list[Any]
+ negate: bool = False
+
+
+class TargetingRule(BaseModel):
+ """A complete targeting rule: all clauses match → serve a variation.
+
+ Parameters
+ ----------
+ id:
+ UUID4 identifier. Used for ordering, references, and scheduling.
+ description:
+ Human-readable label shown in the dashboard.
+ clauses:
+ List of ``RuleClause``. ALL must match (AND logic).
+ variation:
+ Fixed variation name to serve when rule matches.
+ Mutually exclusive with ``rollout``.
+ rollout:
+ Percentage rollout when rule matches.
+ Mutually exclusive with ``variation``.
+ track_events:
+ When ``True``, evaluation events for this rule are always
+ recorded regardless of global event sampling settings.
+ """
+
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ description: str = ""
+ clauses: list[RuleClause] = Field(default_factory=list)
+ variation: str | None = None
+ rollout: list[RolloutVariation] | None = None
+ track_events: bool = False
+
+
+# ── Prerequisites ─────────────────────────────────────────────────────────────
+
+
+class Prerequisite(BaseModel):
+ """A prerequisite flag that must evaluate to a specific variation.
+
+ Parameters
+ ----------
+ flag_key:
+ Key of the prerequisite flag.
+ variation:
+ The variation the prerequisite flag must return.
+ If it returns any other variation, the dependent flag serves
+ its ``off_variation``.
+ """
+
+ flag_key: str
+ variation: str
+
+
+# ── Segments ─────────────────────────────────────────────────────────────────
+
+
+class SegmentRule(BaseModel):
+ """A rule within a segment definition.
+
+ If all clauses match, the context is considered part of the segment.
+ Multiple segment rules are OR-ed (any matching rule → included).
+
+ Parameters
+ ----------
+ id:
+ UUID4 identifier for ordering and deletion.
+ description:
+ Optional human-readable label shown in the dashboard.
+ clauses:
+ List of ``RuleClause``. ALL must match (AND logic) for the rule
+ to match.
+ """
+
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ description: str = ""
+ clauses: list[RuleClause] = Field(default_factory=list)
+
+
+class Segment(BaseModel):
+ """A reusable group of contexts for flag targeting.
+
+ Evaluation order:
+ 1. If ``context.key`` is in ``excluded`` → NOT in segment.
+ 2. If ``context.key`` is in ``included`` → IN segment.
+ 3. Evaluate ``rules`` top-to-bottom — first match → IN segment.
+ 4. No match → NOT in segment.
+
+ Parameters
+ ----------
+ key:
+ Unique identifier. Referenced by ``IN_SEGMENT`` clauses.
+ name:
+ Human-readable display name.
+ included:
+ Explicit context keys always included in this segment.
+ excluded:
+ Explicit context keys always excluded (overrides rules and included).
+ rules:
+ Targeting rules — any matching rule means the context is included.
+ tags:
+ Organisational labels for filtering in the dashboard.
+ """
+
+ key: str
+ name: str
+ description: str = ""
+ included: list[str] = Field(default_factory=list)
+ excluded: list[str] = Field(default_factory=list)
+ rules: list[SegmentRule] = Field(default_factory=list)
+ tags: list[str] = Field(default_factory=list)
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+
+# ── Scheduled changes ─────────────────────────────────────────────────────────
+
+
+class ScheduledChangeAction(StrEnum):
+ """Action to execute at a scheduled time."""
+
+ ENABLE = "enable"
+ DISABLE = "disable"
+ UPDATE_ROLLOUT = "update_rollout"
+ ADD_RULE = "add_rule"
+ DELETE_RULE = "delete_rule"
+
+
+class ScheduledChange(BaseModel):
+ """A pending change to a flag scheduled for future execution.
+
+ Parameters
+ ----------
+ id:
+ UUID4 identifier.
+ execute_at:
+ UTC datetime when the change should fire.
+ action:
+ Which operation to apply to the flag.
+ payload:
+ Action-specific data. E.g. for ``UPDATE_ROLLOUT``::
+
+ {"variation": "on", "weight": 50_000}
+
+ For ``ADD_RULE``: a serialised ``TargetingRule`` dict.
+ For ``DELETE_RULE``: ``{"rule_id": "..."}``.
+ created_by:
+ Actor who scheduled the change (username or ``"system"``).
+ """
+
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ execute_at: datetime
+ action: ScheduledChangeAction
+ payload: dict[str, Any] = Field(default_factory=dict)
+ created_by: str = "system"
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+
+
+# ── Flag lifecycle status ─────────────────────────────────────────────────────
+
+
+class FlagStatus(StrEnum):
+ """Computed lifecycle status of a feature flag.
+
+ Derived from evaluation metrics and configuration — not stored.
+ """
+
+ NEW = "new"
+ """Created recently, never evaluated."""
+
+ ACTIVE = "active"
+ """Being evaluated or recently modified."""
+
+ LAUNCHED = "launched"
+ """Fully rolled out — single variation, stable, safe to clean up."""
+
+ INACTIVE = "inactive"
+ """Not evaluated in 7+ days."""
+
+ DEPRECATED = "deprecated"
+ """Marked deprecated by an operator. Still evaluated if enabled."""
+
+ ARCHIVED = "archived"
+ """Removed from active use. No longer evaluated."""
+
+
+# ── Full flag definition ──────────────────────────────────────────────────────
+
+
+class FeatureFlag(BaseModel):
+ """Full definition of a feature flag.
+
+ Stored in ``ShieldBackend`` alongside ``RouteState``.
+ Backend storage key convention: ``shield:flag:{key}``.
+
+ Parameters
+ ----------
+ key:
+ Unique identifier. Used in code: ``flags.get_boolean_value("my-flag", ...)``.
+ name:
+ Human-readable display name shown in the dashboard.
+ type:
+ Determines valid variation value types.
+ variations:
+ All possible flag values. Must contain at least two variations.
+ off_variation:
+ Variation served when ``enabled=False``. Must match a name in
+ ``variations``.
+ fallthrough:
+ Default rule when no targeting rule matches. Either a fixed
+ variation name (``str``) or a percentage rollout
+ (``list[RolloutVariation]`` summing to 100_000).
+ enabled:
+ Global kill-switch. When ``False``, all requests receive
+ ``off_variation`` regardless of targeting rules.
+ prerequisites:
+ Other flags that must evaluate to specific variations before this
+ flag's rules run. Evaluated recursively. Circular dependencies
+ are prevented at write time.
+ targets:
+ Individual targeting. Maps variation name → list of context keys
+ that always receive that variation. Evaluated after prerequisites,
+ before rules.
+ rules:
+ Targeting rules evaluated top-to-bottom. First match wins.
+ scheduled_changes:
+ Pending future mutations managed by ``FlagScheduler``.
+ temporary:
+ When ``True``, the flag hygiene system may mark it for removal
+ once it reaches ``LAUNCHED`` or ``INACTIVE`` status.
+ maintainer:
+ Username of the person responsible for this flag.
+ """
+
+ key: str
+ name: str
+ description: str = ""
+ type: FlagType
+ tags: list[str] = Field(default_factory=list)
+
+ variations: list[FlagVariation]
+ off_variation: str
+ fallthrough: str | list[RolloutVariation]
+
+ enabled: bool = True
+ prerequisites: list[Prerequisite] = Field(default_factory=list)
+ targets: dict[str, list[str]] = Field(default_factory=dict)
+ rules: list[TargetingRule] = Field(default_factory=list)
+ scheduled_changes: list[ScheduledChange] = Field(default_factory=list)
+
+ status: FlagStatus = FlagStatus.ACTIVE
+ temporary: bool = True
+ maintainer: str | None = None
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
+ created_by: str = "system"
+
+ def get_variation_value(self, name: str) -> Any:
+ """Return the value for the variation with the given name.
+
+ Returns ``None`` if the variation name is not found — callers
+ should validate variation names at write time.
+ """
+ for v in self.variations:
+ if v.name == name:
+ return v.value
+ return None
+
+ def variation_names(self) -> list[str]:
+ """Return all variation names for this flag."""
+ return [v.name for v in self.variations]
+
+
+# ── Evaluation context ────────────────────────────────────────────────────────
+
+
+class EvaluationContext(BaseModel):
+ """Per-request context used for flag targeting.
+
+ This is the primary object application code constructs and passes to
+ ``ShieldFeatureClient.get_*_value()``.
+
+ Parameters
+ ----------
+ key:
+ Required unique identifier for the entity being evaluated.
+ Typically ``user_id``, ``session_id``, or ``org_id``. Used for
+ individual targeting and deterministic rollout bucketing.
+ kind:
+ Context kind. Defaults to ``"user"``. Use ``"organization"``,
+ ``"device"``, or a custom string for non-user contexts.
+ email:
+ Convenience field — accessible in rules as ``"email"`` attribute.
+ ip:
+ Convenience field — accessible in rules as ``"ip"`` attribute.
+ country:
+ Convenience field — accessible in rules as ``"country"`` attribute.
+ app_version:
+ Convenience field — accessible in rules as ``"app_version"``.
+ Use semver operators for version-based targeting.
+ attributes:
+ Arbitrary additional attributes. Keys must be strings.
+ Values can be any JSON-serialisable type.
+
+ Examples
+ --------
+ Minimal context::
+
+ ctx = EvaluationContext(key=request.headers["x-user-id"])
+
+ Rich context::
+
+ ctx = EvaluationContext(
+ key=user.id,
+ kind="user",
+ email=user.email,
+ country=user.country,
+ app_version="2.3.1",
+ attributes={"plan": user.plan, "role": user.role},
+ )
+ """
+
+ key: str
+ kind: str = "user"
+ email: str | None = None
+ ip: str | None = None
+ country: str | None = None
+ app_version: str | None = None
+ attributes: dict[str, Any] = Field(default_factory=dict)
+
+ def all_attributes(self) -> dict[str, Any]:
+ """Merge named convenience fields with ``attributes`` for rule evaluation.
+
+ Named fields take lower priority than ``attributes`` — if the same
+ key appears in both, ``attributes`` wins.
+
+ Returns
+ -------
+ dict[str, Any]
+ Flat dict of all context attributes, including ``"key"`` and
+ ``"kind"`` as first-class attributes.
+ """
+ base: dict[str, Any] = {"key": self.key, "kind": self.kind}
+ for field_name in ("email", "ip", "country", "app_version"):
+ val = getattr(self, field_name)
+ if val is not None:
+ base[field_name] = val
+ return {**base, **self.attributes}
+
+
+# ── Resolution result ─────────────────────────────────────────────────────────
+
+
+class EvaluationReason(StrEnum):
+ """Why a flag returned the value it did.
+
+ Included in ``ResolutionDetails`` for every evaluation.
+ Used by the live events stream, audit hook, and eval debugger.
+ """
+
+ OFF = "OFF"
+ """Flag is globally disabled. ``off_variation`` was served."""
+
+ FALLTHROUGH = "FALLTHROUGH"
+ """No targeting rule matched. Default rule was served."""
+
+ TARGET_MATCH = "TARGET_MATCH"
+ """Context key was in the individual targets list."""
+
+ RULE_MATCH = "RULE_MATCH"
+ """A targeting rule matched. See ``rule_id``."""
+
+ PREREQUISITE_FAIL = "PREREQUISITE_FAIL"
+ """A prerequisite flag did not return the required variation.
+ See ``prerequisite_key``."""
+
+ ERROR = "ERROR"
+ """Provider or evaluation error. Default value was returned."""
+
+ DEFAULT = "DEFAULT"
+ """Flag not found in provider. SDK default was returned."""
+
+
+class ResolutionDetails(BaseModel):
+ """Full result of a feature flag evaluation.
+
+ Application code usually only needs ``.value``. The extra fields
+ are used by hooks, the dashboard live stream, and the eval debugger.
+
+ Parameters
+ ----------
+ value:
+ The resolved flag value.
+ variation:
+ The variation name that was served. ``None`` on error/default.
+ reason:
+ Why this value was returned.
+ rule_id:
+ The ``TargetingRule.id`` that matched. Only set when
+ ``reason == RULE_MATCH``.
+ prerequisite_key:
+ The flag key of the failing prerequisite. Only set when
+ ``reason == PREREQUISITE_FAIL``.
+ error_message:
+ Human-readable error detail. Only set when ``reason == ERROR``.
+ """
+
+ value: Any
+ variation: str | None = None
+ reason: EvaluationReason
+ rule_id: str | None = None
+ prerequisite_key: str | None = None
+ error_message: str | None = None
diff --git a/shield/core/feature_flags/provider.py b/shield/core/feature_flags/provider.py
new file mode 100644
index 0000000..c63e6a4
--- /dev/null
+++ b/shield/core/feature_flags/provider.py
@@ -0,0 +1,199 @@
+"""ShieldOpenFeatureProvider — native OpenFeature provider backed by ShieldBackend.
+
+Phase 2 implementation. Stub present so the package imports cleanly.
+Full implementation wires FlagEvaluator into the OpenFeature resolution API.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from shield.core.feature_flags._guard import _require_flags
+
+_require_flags()
+
+from openfeature.exception import ErrorCode
+from openfeature.flag_evaluation import FlagResolutionDetails, Reason
+from openfeature.provider import AbstractProvider
+from openfeature.provider.metadata import Metadata
+
+from shield.core.feature_flags._context import from_of_context
+from shield.core.feature_flags.evaluator import FlagEvaluator
+from shield.core.feature_flags.models import (
+ EvaluationReason,
+ FeatureFlag,
+ Segment,
+)
+
+if TYPE_CHECKING:
+ from shield.core.backends.base import ShieldBackend
+
+logger = logging.getLogger(__name__)
+
+# Map Shield EvaluationReason → OpenFeature Reason string
+_REASON_MAP: dict[EvaluationReason, str] = {
+ EvaluationReason.OFF: Reason.DISABLED,
+ EvaluationReason.FALLTHROUGH: Reason.DEFAULT,
+ EvaluationReason.TARGET_MATCH: Reason.TARGETING_MATCH,
+ EvaluationReason.RULE_MATCH: Reason.TARGETING_MATCH,
+ EvaluationReason.PREREQUISITE_FAIL: Reason.DISABLED,
+ EvaluationReason.ERROR: Reason.ERROR,
+ EvaluationReason.DEFAULT: Reason.DEFAULT,
+}
+
+
+class ShieldOpenFeatureProvider(AbstractProvider):
+ """OpenFeature-compliant provider backed by ``ShieldBackend``.
+
+ Stores ``FeatureFlag`` and ``Segment`` objects in the same backend
+ as ``RouteState`` — no separate infrastructure required.
+
+ Subscribes to backend pub/sub for instant hot-reload on flag changes.
+ Evaluates flags locally using ``FlagEvaluator`` — zero network calls
+ per evaluation.
+
+ Parameters
+ ----------
+ backend:
+ The ``ShieldBackend`` instance (Memory, File, or Redis).
+ Must be the same instance passed to ``ShieldEngine``.
+ """
+
+ def __init__(self, backend: ShieldBackend) -> None:
+ self._backend = backend
+ self._flags: dict[str, FeatureFlag] = {}
+ self._segments: dict[str, Segment] = {}
+ self._evaluator = FlagEvaluator(segments=self._segments)
+
+ def get_metadata(self) -> Metadata:
+ return Metadata(name="shield")
+
+ def get_provider_hooks(self) -> list[Any]:
+ return []
+
+ def initialize(self, evaluation_context: Any = None) -> None:
+ """No-op sync hook required by the OpenFeature SDK registry.
+
+ The OpenFeature SDK calls this synchronously when ``set_provider()``
+ is invoked. Actual async initialisation (loading flags from the
+ backend) is performed by ``engine.start()`` via ``_load_all()``.
+ """
+
+ def shutdown(self) -> None:
+ """No-op sync hook required by the OpenFeature SDK registry."""
+
+ async def _load_all(self) -> None:
+ """Load all flags and segments from backend into local cache."""
+ try:
+ flags = await self._backend.load_all_flags()
+ self._flags = {f.key: f for f in flags}
+ segments = await self._backend.load_all_segments()
+ self._segments.update({s.key: s for s in segments})
+ except AttributeError:
+ # Backend does not yet support flag storage (pre-Phase 3 backends).
+ # Operate with empty caches — all evaluations return defaults.
+ logger.debug(
+ "api-shield flags: backend does not support flag storage yet. "
+ "All flag evaluations will return defaults."
+ )
+
+ # ── OpenFeature resolution methods ──────────────────────────────────────
+
+ def resolve_boolean_details(
+ self, flag_key: str, default_value: bool, evaluation_context: Any = None
+ ) -> FlagResolutionDetails[Any]:
+ return self._resolve(flag_key, default_value, evaluation_context, bool)
+
+ def resolve_string_details(
+ self, flag_key: str, default_value: str, evaluation_context: Any = None
+ ) -> FlagResolutionDetails[Any]:
+ return self._resolve(flag_key, default_value, evaluation_context, str)
+
+ def resolve_integer_details(
+ self, flag_key: str, default_value: int, evaluation_context: Any = None
+ ) -> FlagResolutionDetails[Any]:
+ return self._resolve(flag_key, default_value, evaluation_context, int)
+
+ def resolve_float_details(
+ self, flag_key: str, default_value: float, evaluation_context: Any = None
+ ) -> FlagResolutionDetails[Any]:
+ return self._resolve(flag_key, default_value, evaluation_context, float)
+
+ def resolve_object_details( # type: ignore[override]
+ self,
+ flag_key: str,
+ default_value: dict[str, Any],
+ evaluation_context: Any = None,
+ ) -> FlagResolutionDetails[Any]:
+ return self._resolve(flag_key, default_value, evaluation_context, dict)
+
+ # ── Internal ────────────────────────────────────────────────────────────
+
+ def _resolve(
+ self,
+ flag_key: str,
+ default_value: Any,
+ of_ctx: Any,
+ expected_type: type,
+ ) -> FlagResolutionDetails[Any]:
+ flag = self._flags.get(flag_key)
+ if flag is None:
+ return FlagResolutionDetails(
+ value=default_value,
+ reason=Reason.DEFAULT,
+ error_code=ErrorCode.FLAG_NOT_FOUND,
+ error_message=f"Flag '{flag_key}' not found",
+ )
+
+ ctx = from_of_context(of_ctx)
+ try:
+ result = self._evaluator.evaluate(flag, ctx, self._flags)
+ except Exception as exc: # noqa: BLE001
+ logger.exception("api-shield flags: evaluation error for '%s'", flag_key)
+ return FlagResolutionDetails(
+ value=default_value,
+ reason=Reason.ERROR,
+ error_code=ErrorCode.GENERAL,
+ error_message=str(exc),
+ )
+
+ value = result.value
+ # Type coercion — ensure returned value matches the expected type
+ if value is None:
+ value = default_value
+ else:
+ try:
+ value = expected_type(value)
+ except (TypeError, ValueError):
+ value = default_value
+
+ flag_metadata: dict[str, int | float | str] = {}
+ if result.rule_id is not None:
+ flag_metadata["rule_id"] = result.rule_id
+ if result.prerequisite_key is not None:
+ flag_metadata["prerequisite_key"] = result.prerequisite_key
+ return FlagResolutionDetails(
+ value=value,
+ variant=result.variation,
+ reason=_REASON_MAP.get(result.reason, Reason.UNKNOWN),
+ flag_metadata=flag_metadata,
+ )
+
+ # ── Flag cache management (called by engine on flag CRUD) ────────────────
+
+ def upsert_flag(self, flag: FeatureFlag) -> None:
+ """Update or insert a flag in the local cache."""
+ self._flags[flag.key] = flag
+
+ def delete_flag(self, flag_key: str) -> None:
+ """Remove a flag from the local cache."""
+ self._flags.pop(flag_key, None)
+
+ def upsert_segment(self, segment: Segment) -> None:
+ """Update or insert a segment in the local cache."""
+ self._segments[segment.key] = segment
+
+ def delete_segment(self, segment_key: str) -> None:
+ """Remove a segment from the local cache."""
+ self._segments.pop(segment_key, None)
diff --git a/shield/core/feature_flags/scheduler.py b/shield/core/feature_flags/scheduler.py
new file mode 100644
index 0000000..8905740
--- /dev/null
+++ b/shield/core/feature_flags/scheduler.py
@@ -0,0 +1,233 @@
+"""FlagScheduler — asyncio.Task-based scheduled flag change runner.
+
+Each :class:`ScheduledChange` on a :class:`FeatureFlag` gets one asyncio
+task that sleeps until ``execute_at``, applies the action to the flag, then
+removes the change from the flag's ``scheduled_changes`` list.
+
+On startup the scheduler scans all flags and re-creates tasks for any
+pending changes whose ``execute_at`` is still in the future (restart
+recovery).
+
+Supported :class:`~shield.core.feature_flags.models.ScheduledChangeAction`\\ s:
+
+* ``ENABLE`` — sets ``flag.enabled = True``
+* ``DISABLE`` — sets ``flag.enabled = False``
+* ``UPDATE_ROLLOUT`` — replaces ``flag.fallthrough`` with a new variation
+ name or rollout list from ``payload``
+* ``ADD_RULE`` — appends a :class:`TargetingRule` parsed from ``payload``
+* ``DELETE_RULE`` — removes the rule with ``payload["rule_id"]``
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING, Any
+
+import anyio
+
+if TYPE_CHECKING:
+ from shield.core.engine import ShieldEngine
+
+logger = logging.getLogger(__name__)
+
+
+class FlagScheduler:
+ """Manages scheduled flag changes using ``asyncio.Task`` objects.
+
+ Parameters
+ ----------
+ engine:
+ The :class:`~shield.core.engine.ShieldEngine` used to read and
+ write flags.
+ """
+
+ def __init__(self, engine: ShieldEngine) -> None:
+ self._engine = engine
+ # (flag_key, change_id) → running task
+ self._tasks: dict[tuple[str, str], asyncio.Task[None]] = {}
+
+ # ------------------------------------------------------------------
+ # Lifecycle
+ # ------------------------------------------------------------------
+
+ async def start(self) -> None:
+ """Restore pending scheduled changes from all flags.
+
+ Called by ``ShieldEngine.start()`` when feature flags are enabled.
+ """
+ try:
+ flags = await self._engine.list_flags()
+ except Exception:
+ logger.exception("FlagScheduler: failed to load flags on startup")
+ return
+
+ now = datetime.now(UTC)
+ count = 0
+ for flag in flags:
+ for change in list(flag.scheduled_changes):
+ execute_at = change.execute_at
+ if execute_at.tzinfo is None:
+ execute_at = execute_at.replace(tzinfo=UTC)
+ if execute_at > now:
+ self._create_task(flag.key, change)
+ count += 1
+ if count:
+ logger.info("FlagScheduler: restored %d pending scheduled change(s)", count)
+
+ async def stop(self) -> None:
+ """Cancel all pending scheduled change tasks."""
+ for task in list(self._tasks.values()):
+ task.cancel()
+ try:
+ await task
+ except (asyncio.CancelledError, Exception):
+ pass
+ self._tasks.clear()
+
+ # ------------------------------------------------------------------
+ # Public API
+ # ------------------------------------------------------------------
+
+ async def schedule(self, flag_key: str, change: Any) -> None:
+ """Register a new scheduled change task.
+
+ If a task already exists for ``(flag_key, change.id)`` it is
+ cancelled and replaced.
+
+ Parameters
+ ----------
+ flag_key:
+ Key of the flag that owns the change.
+ change:
+ A :class:`~shield.core.feature_flags.models.ScheduledChange`
+ instance already appended to the flag's ``scheduled_changes``
+ list and persisted to the backend.
+ """
+ await self.cancel(flag_key, change.id)
+ self._create_task(flag_key, change)
+
+ async def cancel(self, flag_key: str, change_id: str) -> None:
+ """Cancel the task for a specific scheduled change, if any."""
+ task = self._tasks.pop((flag_key, change_id), None)
+ if task is not None:
+ task.cancel()
+ try:
+ await task
+ except (asyncio.CancelledError, Exception):
+ pass
+
+ async def cancel_all_for_flag(self, flag_key: str) -> None:
+ """Cancel all pending tasks for *flag_key* (e.g. when a flag is deleted)."""
+ keys_to_cancel = [k for k in self._tasks if k[0] == flag_key]
+ for k in keys_to_cancel:
+ task = self._tasks.pop(k)
+ task.cancel()
+ try:
+ await task
+ except (asyncio.CancelledError, Exception):
+ pass
+
+ def list_pending(self) -> list[dict[str, str]]:
+ """Return a list of ``{"flag_key": ..., "change_id": ...}`` dicts."""
+ return [{"flag_key": fk, "change_id": cid} for fk, cid in self._tasks]
+
+ # ------------------------------------------------------------------
+ # Internals
+ # ------------------------------------------------------------------
+
+ def _create_task(self, flag_key: str, change: Any) -> asyncio.Task[None]:
+ task = asyncio.create_task(
+ self._run_change(flag_key, change),
+ name=f"shield-flag-scheduler:{flag_key}:{change.id}",
+ )
+ self._tasks[(flag_key, change.id)] = task
+ task.add_done_callback(lambda t: self._tasks.pop((flag_key, change.id), None))
+ return task
+
+ async def _run_change(self, flag_key: str, change: Any) -> None:
+ """Sleep until ``execute_at``, then apply the change to the flag."""
+ execute_at = change.execute_at
+ if execute_at.tzinfo is None:
+ execute_at = execute_at.replace(tzinfo=UTC)
+
+ now = datetime.now(UTC)
+ delay = (execute_at - now).total_seconds()
+ if delay > 0:
+ try:
+ await anyio.sleep(delay)
+ except asyncio.CancelledError:
+ return
+
+ logger.info(
+ "FlagScheduler: executing change %s (action=%s) on flag %r",
+ change.id,
+ change.action,
+ flag_key,
+ )
+ try:
+ await self._apply_change(flag_key, change)
+ except Exception:
+ logger.exception(
+ "FlagScheduler: error applying change %s on flag %r", change.id, flag_key
+ )
+
+ async def _apply_change(self, flag_key: str, change: Any) -> None:
+ """Load the flag, mutate it, remove the change, and persist."""
+ from shield.core.feature_flags.models import ScheduledChangeAction, TargetingRule
+
+ flag = await self._engine.get_flag(flag_key)
+ if flag is None:
+ logger.warning(
+ "FlagScheduler: flag %r not found when applying change %s — skipping",
+ flag_key,
+ change.id,
+ )
+ return
+
+ action = change.action
+ payload = change.payload or {}
+
+ if action == ScheduledChangeAction.ENABLE:
+ flag = flag.model_copy(update={"enabled": True})
+ elif action == ScheduledChangeAction.DISABLE:
+ flag = flag.model_copy(update={"enabled": False})
+ elif action == ScheduledChangeAction.UPDATE_ROLLOUT:
+ new_fallthrough = payload.get("variation") or payload.get("rollout")
+ if new_fallthrough is not None:
+ flag = flag.model_copy(update={"fallthrough": new_fallthrough})
+ else:
+ logger.warning(
+ "FlagScheduler: UPDATE_ROLLOUT payload missing 'variation' for flag %r",
+ flag_key,
+ )
+ elif action == ScheduledChangeAction.ADD_RULE:
+ try:
+ new_rule = TargetingRule.model_validate(payload)
+ updated_rules = list(flag.rules) + [new_rule]
+ flag = flag.model_copy(update={"rules": updated_rules})
+ except Exception as exc:
+ logger.error(
+ "FlagScheduler: ADD_RULE payload invalid for flag %r: %s", flag_key, exc
+ )
+ return
+ elif action == ScheduledChangeAction.DELETE_RULE:
+ rule_id = payload.get("rule_id")
+ updated_rules = [r for r in flag.rules if r.id != rule_id]
+ flag = flag.model_copy(update={"rules": updated_rules})
+ else:
+ logger.warning("FlagScheduler: unknown action %r for flag %r", action, flag_key)
+ return
+
+ # Remove the executed change from the flag's scheduled_changes list.
+ remaining = [c for c in flag.scheduled_changes if c.id != change.id]
+ flag = flag.model_copy(update={"scheduled_changes": remaining})
+ await self._engine.save_flag(flag)
+
+ logger.info(
+ "FlagScheduler: applied %s to flag %r (change %s)",
+ action,
+ flag_key,
+ change.id,
+ )
diff --git a/shield/dashboard/routes.py b/shield/dashboard/routes.py
index b353ec9..fceb9cf 100644
--- a/shield/dashboard/routes.py
+++ b/shield/dashboard/routes.py
@@ -1040,6 +1040,8 @@ async def _generate() -> object:
# Keepalive ping loop — runs when subscribe() is unsupported OR after
# the subscription ends. Browsers keep the connection alive.
while True:
+ if await request.is_disconnected():
+ break
yield ": keepalive\n\n"
try:
await anyio.sleep(15)
@@ -1054,3 +1056,851 @@ async def _generate() -> object:
"X-Accel-Buffering": "no",
},
)
+
+
+# ---------------------------------------------------------------------------
+# Feature flag dashboard pages
+# ---------------------------------------------------------------------------
+
+_FLAG_TYPE_COLOURS = {
+ "boolean": "emerald",
+ "string": "blue",
+ "integer": "violet",
+ "float": "violet",
+ "json": "amber",
+}
+
+
+async def flags_page(request: Request) -> Response:
+ """GET /flags — feature flag list page."""
+ tpl = _templates(request)
+ engine = _engine(request)
+ prefix = _prefix(request)
+ flags = await engine.list_flags()
+ return tpl.TemplateResponse(
+ request,
+ "flags.html",
+ {
+ "prefix": prefix,
+ "flags": flags,
+ "active_tab": "flags",
+ "shield_actor": _actor(request),
+ "version": request.app.state.version,
+ "flag_type_colours": _FLAG_TYPE_COLOURS,
+ "flags_enabled": True,
+ },
+ )
+
+
+async def flags_rows_partial(request: Request) -> Response:
+ """GET /flags/rows — HTMX partial: flag table rows only.
+
+ Supports ``?q=`` search query and ``?type=`` / ``?status=`` filters.
+ """
+ tpl = _templates(request)
+ engine = _engine(request)
+ prefix = _prefix(request)
+ flags = await engine.list_flags()
+
+ q = request.query_params.get("q", "").lower().strip()
+ ftype = request.query_params.get("type", "").strip()
+ status_filter = request.query_params.get("status", "").strip()
+
+ if q:
+ flags = [f for f in flags if q in f.key.lower() or q in (f.name or "").lower()]
+ if ftype:
+ flags = [f for f in flags if f.type.value == ftype]
+ if status_filter == "enabled":
+ flags = [f for f in flags if f.enabled]
+ elif status_filter == "disabled":
+ flags = [f for f in flags if not f.enabled]
+
+ return tpl.TemplateResponse(
+ request,
+ "partials/flag_rows.html",
+ {
+ "prefix": prefix,
+ "flags": flags,
+ "flag_type_colours": _FLAG_TYPE_COLOURS,
+ },
+ )
+
+
+async def flag_detail_page(request: Request) -> Response:
+ """GET /flags/{key} — single flag detail page."""
+ tpl = _templates(request)
+ engine = _engine(request)
+ prefix = _prefix(request)
+ key = request.path_params["key"]
+ flag = await engine.get_flag(key)
+ if flag is None:
+ return HTMLResponse("