Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 25 additions & 181 deletions cmd/thv-operator/controllers/mcpserver_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/api/resource"
Expand All @@ -32,6 +31,7 @@
"sigs.k8s.io/controller-runtime/pkg/log"

mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
"github.com/stacklok/toolhive/cmd/thv-operator/pkg/rbac"
"github.com/stacklok/toolhive/cmd/thv-operator/pkg/validation"
"github.com/stacklok/toolhive/pkg/container/kubernetes"
)
Expand All @@ -44,43 +44,10 @@
detectedPlatform kubernetes.Platform
platformOnce sync.Once
ImageValidation validation.ImageValidation
RBACManager rbac.Manager
}

// defaultRBACRules are the default RBAC rules that the
// ToolHive ProxyRunner and/or MCP server needs to have in order to run.
var defaultRBACRules = []rbacv1.PolicyRule{
{
APIGroups: []string{"apps"},
Resources: []string{"statefulsets"},
Verbs: []string{"get", "list", "watch", "create", "update", "patch", "delete", "apply"},
},
{
APIGroups: []string{""},
Resources: []string{"services"},
Verbs: []string{"get", "list", "watch", "create", "update", "patch", "delete", "apply"},
},
{
APIGroups: []string{""},
Resources: []string{"pods"},
Verbs: []string{"get", "list", "watch"},
},
{
APIGroups: []string{""},
Resources: []string{"pods/log"},
Verbs: []string{"get"},
},
{
APIGroups: []string{""},
Resources: []string{"pods/attach"},
Verbs: []string{"create", "get"},
},
{
APIGroups: []string{""},
Resources: []string{"configmaps"},
Verbs: []string{"get", "list", "watch"},
},
}

Check failure on line 50 in cmd/thv-operator/controllers/mcpserver_controller.go

View workflow job for this annotation

GitHub Actions / Linting / Lint Go Code

File is not properly formatted (gci)
// mcpContainerName is the name of the mcp container used in pod templates
const mcpContainerName = "mcp"

Expand Down Expand Up @@ -545,82 +512,6 @@
return nil
}

// ensureRBACResource is a generic helper function to ensure a Kubernetes resource exists and is up to date
func (r *MCPServerReconciler) ensureRBACResource(
ctx context.Context,
mcpServer *mcpv1alpha1.MCPServer,
resourceType string,
createResource func() client.Object,
) error {
current := createResource()
objectKey := types.NamespacedName{Name: current.GetName(), Namespace: current.GetNamespace()}
err := r.Get(ctx, objectKey, current)

if errors.IsNotFound(err) {
return r.createRBACResource(ctx, mcpServer, resourceType, createResource)
} else if err != nil {
return fmt.Errorf("failed to get %s: %w", resourceType, err)
}

return r.updateRBACResourceIfNeeded(ctx, mcpServer, resourceType, createResource, current)
}

// createRBACResource creates a new RBAC resource
func (r *MCPServerReconciler) createRBACResource(
ctx context.Context,
mcpServer *mcpv1alpha1.MCPServer,
resourceType string,
createResource func() client.Object,
) error {
ctxLogger := log.FromContext(ctx)
desired := createResource()
if err := controllerutil.SetControllerReference(mcpServer, desired, r.Scheme); err != nil {
ctxLogger.Error(err, "Failed to set controller reference", "resourceType", resourceType)
return nil
}

ctxLogger.Info(
fmt.Sprintf("%s does not exist, creating %s", resourceType, resourceType),
fmt.Sprintf("%s.Name", resourceType),
desired.GetName(),
)
if err := r.Create(ctx, desired); err != nil {
return fmt.Errorf("failed to create %s: %w", resourceType, err)
}
ctxLogger.Info(fmt.Sprintf("%s created", resourceType), fmt.Sprintf("%s.Name", resourceType), desired.GetName())
return nil
}

// updateRBACResourceIfNeeded updates an RBAC resource if changes are detected
func (r *MCPServerReconciler) updateRBACResourceIfNeeded(
ctx context.Context,
mcpServer *mcpv1alpha1.MCPServer,
resourceType string,
createResource func() client.Object,
current client.Object,
) error {
ctxLogger := log.FromContext(ctx)
desired := createResource()
if err := controllerutil.SetControllerReference(mcpServer, desired, r.Scheme); err != nil {
ctxLogger.Error(err, "Failed to set controller reference", "resourceType", resourceType)
return nil
}

if !reflect.DeepEqual(current, desired) {
ctxLogger.Info(
fmt.Sprintf("%s exists, updating %s", resourceType, resourceType),
fmt.Sprintf("%s.Name", resourceType),
desired.GetName(),
)
if err := r.Update(ctx, desired); err != nil {
return fmt.Errorf("failed to update %s: %w", resourceType, err)
}
ctxLogger.Info(fmt.Sprintf("%s updated", resourceType), fmt.Sprintf("%s.Name", resourceType), desired.GetName())
}
return nil
}

// ensureRBACResources ensures that the RBAC resources are in place for the MCP server

// handleToolConfig handles MCPToolConfig reference for an MCPServer
func (r *MCPServerReconciler) handleToolConfig(ctx context.Context, m *mcpv1alpha1.MCPServer) error {
Expand Down Expand Up @@ -667,71 +558,16 @@
return nil
}
func (r *MCPServerReconciler) ensureRBACResources(ctx context.Context, mcpServer *mcpv1alpha1.MCPServer) error {
proxyRunnerNameForRBAC := proxyRunnerServiceAccountName(mcpServer.Name)

// Ensure Role
if err := r.ensureRBACResource(ctx, mcpServer, "Role", func() client.Object {
return &rbacv1.Role{
ObjectMeta: metav1.ObjectMeta{
Name: proxyRunnerNameForRBAC,
Namespace: mcpServer.Namespace,
},
Rules: defaultRBACRules,
}
}); err != nil {
return err
}

// Ensure ServiceAccount
if err := r.ensureRBACResource(ctx, mcpServer, "ServiceAccount", func() client.Object {
return &corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: proxyRunnerNameForRBAC,
Namespace: mcpServer.Namespace,
},
}
}); err != nil {
return err
}

if err := r.ensureRBACResource(ctx, mcpServer, "RoleBinding", func() client.Object {
return &rbacv1.RoleBinding{
ObjectMeta: metav1.ObjectMeta{
Name: proxyRunnerNameForRBAC,
Namespace: mcpServer.Namespace,
},
RoleRef: rbacv1.RoleRef{
APIGroup: "rbac.authorization.k8s.io",
Kind: "Role",
Name: proxyRunnerNameForRBAC,
},
Subjects: []rbacv1.Subject{
{
Kind: "ServiceAccount",
Name: proxyRunnerNameForRBAC,
Namespace: mcpServer.Namespace,
},
},
}
}); err != nil {
return err
}

// If a service account is specified, we don't need to create one
if mcpServer.Spec.ServiceAccount != nil {
return nil
// Initialize RBACManager if not already initialized
if r.RBACManager == nil {
r.RBACManager = rbac.NewManager(rbac.Config{
Client: r.Client,
Scheme: r.Scheme,
DefaultRBACRules: nil, // Use default rules from the package
})
}

// otherwise, create a service account for the MCP server
mcpServerServiceAccountName := mcpServerServiceAccountName(mcpServer.Name)
return r.ensureRBACResource(ctx, mcpServer, "ServiceAccount", func() client.Object {
return &corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: mcpServerServiceAccountName,
Namespace: mcpServer.Namespace,
},
}
})
return r.RBACManager.EnsureRBACResources(ctx, mcpServer)
}

// deploymentForMCPServer returns a MCPServer Deployment object
Expand All @@ -755,7 +591,7 @@
// If service account is not specified, use the default MCP server service account
serviceAccount := m.Spec.ServiceAccount
if serviceAccount == nil {
defaultSA := mcpServerServiceAccountName(m.Name)
defaultSA := r.mcpServerServiceAccountName(m.Name)
serviceAccount = &defaultSA
}
finalPodTemplateSpec := NewMCPServerPodTemplateSpecBuilder(m.Spec.PodTemplateSpec).
Expand Down Expand Up @@ -817,7 +653,7 @@
// If service account is not specified, use the default MCP server service account
serviceAccount := m.Spec.ServiceAccount
if serviceAccount == nil {
defaultSA := mcpServerServiceAccountName(m.Name)
defaultSA := r.mcpServerServiceAccountName(m.Name)
serviceAccount = &defaultSA
}
finalPodTemplateSpec := NewMCPServerPodTemplateSpecBuilder(m.Spec.PodTemplateSpec).
Expand Down Expand Up @@ -1065,7 +901,7 @@
Annotations: deploymentTemplateAnnotations,
},
Spec: corev1.PodSpec{
ServiceAccountName: proxyRunnerServiceAccountName(m.Name),
ServiceAccountName: r.proxyRunnerServiceAccountName(m.Name),
Containers: []corev1.Container{{
Image: getToolhiveRunnerImage(),
Name: "toolhive",
Expand Down Expand Up @@ -1466,7 +1302,7 @@
// If service account is not specified, use the default MCP server service account
serviceAccount := mcpServer.Spec.ServiceAccount
if serviceAccount == nil {
defaultSA := mcpServerServiceAccountName(mcpServer.Name)
defaultSA := r.mcpServerServiceAccountName(mcpServer.Name)
serviceAccount = &defaultSA
}
expectedPodTemplateSpec := NewMCPServerPodTemplateSpecBuilder(mcpServer.Spec.PodTemplateSpec).
Expand Down Expand Up @@ -1540,7 +1376,7 @@

// Check if the service account name has changed
// ServiceAccountName: treat empty (not yet set) as equal to the expected default
expectedServiceAccountName := proxyRunnerServiceAccountName(mcpServer.Name)
expectedServiceAccountName := r.proxyRunnerServiceAccountName(mcpServer.Name)
currentServiceAccountName := deployment.Spec.Template.Spec.ServiceAccountName
if currentServiceAccountName != "" && currentServiceAccountName != expectedServiceAccountName {
return true
Expand Down Expand Up @@ -1626,12 +1462,20 @@
}

// proxyRunnerServiceAccountName returns the service account name for the proxy runner
func proxyRunnerServiceAccountName(mcpServerName string) string {
func (r *MCPServerReconciler) proxyRunnerServiceAccountName(mcpServerName string) string {
if r.RBACManager != nil {
return r.RBACManager.GetProxyRunnerServiceAccountName(mcpServerName)
}
// Fallback for cases where RBACManager is not initialized yet
return fmt.Sprintf("%s-proxy-runner", mcpServerName)
}

// mcpServerServiceAccountName returns the service account name for the mcp server
func mcpServerServiceAccountName(mcpServerName string) string {
func (r *MCPServerReconciler) mcpServerServiceAccountName(mcpServerName string) string {
if r.RBACManager != nil {
return r.RBACManager.GetMCPServerServiceAccountName(mcpServerName)
}
// Fallback for cases where RBACManager is not initialized yet
return fmt.Sprintf("%s-sa", mcpServerName)
}

Expand Down
16 changes: 12 additions & 4 deletions cmd/thv-operator/controllers/mcpserver_rbac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client/fake"

mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
"github.com/stacklok/toolhive/cmd/thv-operator/pkg/rbac"
)

type testContext struct {
Expand All @@ -31,12 +32,19 @@ func setupTest(name, namespace string) *testContext {
testScheme := createTestScheme()
fakeClient := fake.NewClientBuilder().WithScheme(testScheme).Build()
proxyRunnerNameForRBAC := fmt.Sprintf("%s-proxy-runner", name)
rbacManager := rbac.NewManager(rbac.Config{
Client: fakeClient,
Scheme: testScheme,
DefaultRBACRules: nil, // Use default rules
})

return &testContext{
mcpServer: mcpServer,
client: fakeClient,
reconciler: &MCPServerReconciler{
Client: fakeClient,
Scheme: testScheme,
Client: fakeClient,
Scheme: testScheme,
RBACManager: rbacManager,
},
proxyRunnerNameForRBAC: proxyRunnerNameForRBAC,
}
Expand Down Expand Up @@ -68,7 +76,7 @@ func (tc *testContext) assertRoleExists(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, tc.proxyRunnerNameForRBAC, role.Name)
assert.Equal(t, tc.mcpServer.Namespace, role.Namespace)
assert.Equal(t, defaultRBACRules, role.Rules)
assert.Equal(t, rbac.GetDefaultRBACRules(), role.Rules)
}

func (tc *testContext) assertRoleBindingExists(t *testing.T) {
Expand Down Expand Up @@ -276,7 +284,7 @@ func TestEnsureRBACResources_NoChangesNeeded(t *testing.T) {
Name: tc.proxyRunnerNameForRBAC,
Namespace: tc.mcpServer.Namespace,
},
Rules: defaultRBACRules,
Rules: rbac.GetDefaultRBACRules(),
}
err = tc.client.Create(context.TODO(), role)
require.NoError(t, err)
Expand Down
9 changes: 9 additions & 0 deletions cmd/thv-operator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
"github.com/stacklok/toolhive/cmd/thv-operator/controllers"
"github.com/stacklok/toolhive/cmd/thv-operator/pkg/rbac"
"github.com/stacklok/toolhive/cmd/thv-operator/pkg/validation"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/operator/telemetry"
Expand Down Expand Up @@ -75,10 +76,18 @@ func main() {
os.Exit(1)
}

// Initialize RBAC Manager
rbacManager := rbac.NewManager(rbac.Config{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
DefaultRBACRules: nil, // Use default rules from the package
})

rec := &controllers.MCPServerReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
ImageValidation: validation.ImageValidationAlwaysAllow,
RBACManager: rbacManager,
}

if err = rec.SetupWithManager(mgr); err != nil {
Expand Down
Loading
Loading