diff --git a/executor_test.go b/collection_test.go similarity index 68% rename from executor_test.go rename to collection_test.go index e7da168..0cb13ac 100644 --- a/executor_test.go +++ b/collection_test.go @@ -2,49 +2,26 @@ package gomongo_test import ( "context" - "slices" "strings" "testing" "time" "github.com/bytebase/gomongo" + "github.com/bytebase/gomongo/internal/testutil" "github.com/stretchr/testify/require" - "github.com/testcontainers/testcontainers-go/modules/mongodb" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" - "go.mongodb.org/mongo-driver/v2/mongo/options" ) -func setupTestContainer(t *testing.T) (*mongo.Client, func()) { - ctx := context.Background() - - mongodbContainer, err := mongodb.Run(ctx, "mongo:7") - require.NoError(t, err) - - connectionString, err := mongodbContainer.ConnectionString(ctx) - require.NoError(t, err) - - client, err := mongo.Connect(options.Client().ApplyURI(connectionString)) - require.NoError(t, err) - - cleanup := func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - _ = client.Disconnect(ctx) - _ = mongodbContainer.Terminate(ctx) - } - - return client, cleanup -} - func TestFindEmptyCollection(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_find_empty" + defer testutil.CleanupDatabase(t, client, dbName) gc := gomongo.NewClient(client) ctx := context.Background() - result, err := gc.Execute(ctx, "testdb", "db.users.find()") + result, err := gc.Execute(ctx, dbName, "db.users.find()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 0, result.RowCount) @@ -52,13 +29,14 @@ func TestFindEmptyCollection(t *testing.T) { } func TestFindWithDocuments(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_find_docs" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Insert test documents - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30}, bson.M{"name": "bob", "age": 25}, @@ -66,7 +44,7 @@ func TestFindWithDocuments(t *testing.T) { require.NoError(t, err) gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", "db.users.find()") + result, err := gc.Execute(ctx, dbName, "db.users.find()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 2, result.RowCount) @@ -81,30 +59,32 @@ func TestFindWithDocuments(t *testing.T) { } func TestFindWithEmptyFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_find_empty_filter" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("items") + collection := client.Database(dbName).Collection("items") _, err := collection.InsertOne(ctx, bson.M{"item": "test"}) require.NoError(t, err) gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", "db.items.find({})") + result, err := gc.Execute(ctx, dbName, "db.items.find({})") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) } func TestFindOneEmptyCollection(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_findone_empty" + defer testutil.CleanupDatabase(t, client, dbName) gc := gomongo.NewClient(client) ctx := context.Background() - result, err := gc.Execute(ctx, "testdb", "db.users.findOne()") + result, err := gc.Execute(ctx, dbName, "db.users.findOne()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 0, result.RowCount) @@ -112,12 +92,13 @@ func TestFindOneEmptyCollection(t *testing.T) { } func TestFindOneWithDocuments(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_findone_docs" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30}, bson.M{"name": "bob", "age": 25}, @@ -125,7 +106,7 @@ func TestFindOneWithDocuments(t *testing.T) { require.NoError(t, err) gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", "db.users.findOne()") + result, err := gc.Execute(ctx, dbName, "db.users.findOne()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) @@ -136,12 +117,13 @@ func TestFindOneWithDocuments(t *testing.T) { } func TestFindOneWithFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_findone_filter" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30}, bson.M{"name": "bob", "age": 25}, @@ -191,7 +173,7 @@ func TestFindOneWithFilter(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) if tc.expectMatch { @@ -207,12 +189,13 @@ func TestFindOneWithFilter(t *testing.T) { } func TestFindOneWithOptions(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_findone_options" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("items") + collection := client.Database(dbName).Collection("items") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "apple", "price": 1}, bson.M{"name": "banana", "price": 2}, @@ -261,7 +244,7 @@ func TestFindOneWithOptions(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) @@ -270,69 +253,14 @@ func TestFindOneWithOptions(t *testing.T) { } } -func TestParseError(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - gc := gomongo.NewClient(client) - ctx := context.Background() - - _, err := gc.Execute(ctx, "testdb", "db.users.find({ name: })") - require.Error(t, err) - - var parseErr *gomongo.ParseError - require.ErrorAs(t, err, &parseErr) -} - -func TestPlannedOperation(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - gc := gomongo.NewClient(client) - ctx := context.Background() - - // insertOne is a planned M2 operation - should return PlannedOperationError - _, err := gc.Execute(ctx, "testdb", "db.users.insertOne({ name: 'test' })") - require.Error(t, err) - - var plannedErr *gomongo.PlannedOperationError - require.ErrorAs(t, err, &plannedErr) - require.Equal(t, "insertOne()", plannedErr.Operation) -} - -func TestUnsupportedOperation(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - gc := gomongo.NewClient(client) - ctx := context.Background() - - // createSearchIndex is NOT in the registry - should return UnsupportedOperationError - _, err := gc.Execute(ctx, "testdb", `db.movies.createSearchIndex({ name: "default", definition: { mappings: { dynamic: true } } })`) - require.Error(t, err) - - var unsupportedErr *gomongo.UnsupportedOperationError - require.ErrorAs(t, err, &unsupportedErr) - require.Equal(t, "createSearchIndex()", unsupportedErr.Operation) -} - -func TestMethodRegistryStats(t *testing.T) { - total := gomongo.MethodRegistryStats() - - // Registry should contain M2 (10) + M3 (22) = 32 planned methods - require.Equal(t, 32, total, "expected 32 planned methods in registry (M2: 10, M3: 22)") - - // Log stats for visibility - t.Logf("Method Registry Stats: total=%d planned methods", total) -} - func TestFindWithFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_find_filter" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30, "active": true}, bson.M{"name": "bob", "age": 25, "active": false}, @@ -401,7 +329,7 @@ func TestFindWithFilter(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, tc.expectedCount, result.RowCount) @@ -413,12 +341,13 @@ func TestFindWithFilter(t *testing.T) { } func TestFindWithCursorModifications(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_find_cursor" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("items") + collection := client.Database(dbName).Collection("items") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "apple", "price": 1, "category": "fruit"}, bson.M{"name": "banana", "price": 2, "category": "fruit"}, @@ -539,7 +468,7 @@ func TestFindWithCursorModifications(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, tc.expectedCount, result.RowCount) @@ -550,134 +479,200 @@ func TestFindWithCursorModifications(t *testing.T) { } } -func TestCollectionAccessPatterns(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestFindWithProjectionArg(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_find_proj_arg" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Insert a document - collection := client.Database("testdb").Collection("my-collection") - _, err := collection.InsertOne(ctx, bson.M{"data": "test"}) + // Insert test data + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "age": 30, "city": "NYC"}, + bson.M{"name": "Bob", "age": 25, "city": "LA"}, + }) require.NoError(t, err) gc := gomongo.NewClient(client) - tests := []struct { - name string - statement string - }{ - {"dot access", "db.users.find()"}, - {"bracket double quote", `db["my-collection"].find()`}, - {"bracket single quote", `db['my-collection'].find()`}, - {"getCollection", `db.getCollection("my-collection").find()`}, - } + // find with projection as 2nd argument + result, err := gc.Execute(ctx, dbName, `db.users.find({}, { name: 1, _id: 0 })`) + require.NoError(t, err) + require.Equal(t, 2, result.RowCount) - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) - require.NoError(t, err) - require.NotNil(t, result) - }) + // Verify only 'name' field is returned + for _, row := range result.Rows { + require.Contains(t, row, "name") + require.NotContains(t, row, "age") + require.NotContains(t, row, "city") } } -func TestShowDatabases(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestFindWithHintOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_find_hint" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Create a database by inserting a document - _, err := client.Database("mydb").Collection("test").InsertOne(ctx, bson.M{"x": 1}) + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "age": 30}, + bson.M{"name": "Bob", "age": 25}, + }) + require.NoError(t, err) + + // Create index + _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ + Keys: bson.D{{Key: "name", Value: 1}}, + }) require.NoError(t, err) gc := gomongo.NewClient(client) - tests := []struct { - name string - statement string - }{ - {"show dbs", "show dbs"}, - {"show databases", "show databases"}, - } + // find with hint option (index name) + result, err := gc.Execute(ctx, dbName, `db.users.find({}, {}, { hint: "name_1" })`) + require.NoError(t, err) + require.Equal(t, 2, result.RowCount) +} - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "mydb", tc.statement) - require.NoError(t, err) - require.NotNil(t, result) - require.GreaterOrEqual(t, result.RowCount, 1) +func TestFindWithMaxMinOptions(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_find_maxmin" + defer testutil.CleanupDatabase(t, client, dbName) - // Check that mydb is in the result - require.True(t, slices.Contains(result.Rows, "mydb"), "expected 'mydb' in database list, got: %v", result.Rows) - }) - } + ctx := context.Background() + + coll := client.Database(dbName).Collection("items") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"price": 10}, + bson.M{"price": 20}, + bson.M{"price": 30}, + bson.M{"price": 40}, + bson.M{"price": 50}, + }) + require.NoError(t, err) + + // Create index on price field (required for min/max) + _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ + Keys: bson.D{{Key: "price", Value: 1}}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // find with min and max options (requires hint) + result, err := gc.Execute(ctx, dbName, `db.items.find({}, {}, { hint: { price: 1 }, min: { price: 20 }, max: { price: 40 } })`) + require.NoError(t, err) + // Should return items with price 20 and 30 (max is exclusive) + require.Equal(t, 2, result.RowCount) } -func TestShowCollections(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestFindWithMaxTimeMSOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_find_maxtime" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Create collections by inserting documents - _, err := client.Database("testdb").Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) - require.NoError(t, err) - _, err = client.Database("testdb").Collection("orders").InsertOne(ctx, bson.M{"item": "book"}) + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice"}, + bson.M{"name": "Bob"}, + }) require.NoError(t, err) gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", "show collections") + // find with maxTimeMS option + result, err := gc.Execute(ctx, dbName, `db.users.find({}, {}, { maxTimeMS: 5000 })`) require.NoError(t, err) - require.NotNil(t, result) require.Equal(t, 2, result.RowCount) +} - // Check that both collections are in the result - collectionSet := make(map[string]bool) - for _, row := range result.Rows { - collectionSet[row] = true - } - require.True(t, collectionSet["users"], "expected 'users' collection") - require.True(t, collectionSet["orders"], "expected 'orders' collection") +func TestFindOneWithProjectionAndOptions(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_findone_proj_opts" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "age": 30, "city": "NYC"}, + bson.M{"name": "Bob", "age": 25, "city": "LA"}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // findOne with projection as 2nd argument + result, err := gc.Execute(ctx, dbName, `db.users.findOne({}, { name: 1, _id: 0 })`) + require.NoError(t, err) + require.Equal(t, 1, result.RowCount) + require.Contains(t, result.Rows[0], "name") + require.NotContains(t, result.Rows[0], "age") } -func TestGetCollectionNames(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestFindOneWithHintOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_findone_hint" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Create collections by inserting documents - _, err := client.Database("testdb").Collection("products").InsertOne(ctx, bson.M{"name": "widget"}) + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "age": 30}, + bson.M{"name": "Bob", "age": 25}, + }) require.NoError(t, err) - _, err = client.Database("testdb").Collection("categories").InsertOne(ctx, bson.M{"name": "electronics"}) + + // Create index + _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ + Keys: bson.D{{Key: "name", Value: 1}}, + }) require.NoError(t, err) gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", "db.getCollectionNames()") + // findOne with hint option (index name) + result, err := gc.Execute(ctx, dbName, `db.users.findOne({}, {}, { hint: "name_1" })`) require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 2, result.RowCount) + require.Equal(t, 1, result.RowCount) +} - // Check that both collections are in the result - collectionSet := make(map[string]bool) - for _, row := range result.Rows { - collectionSet[row] = true - } - require.True(t, collectionSet["products"], "expected 'products' collection") - require.True(t, collectionSet["categories"], "expected 'categories' collection") +func TestFindOneWithMaxTimeMSOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_findone_maxtime" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice"}, + bson.M{"name": "Bob"}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // findOne with maxTimeMS option + result, err := gc.Execute(ctx, dbName, `db.users.findOne({}, {}, { maxTimeMS: 5000 })`) + require.NoError(t, err) + require.Equal(t, 1, result.RowCount) } func TestAggregateBasic(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_basic" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("items") + collection := client.Database(dbName).Collection("items") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "apple", "price": 1, "category": "fruit"}, bson.M{"name": "banana", "price": 2, "category": "fruit"}, @@ -777,7 +772,7 @@ func TestAggregateBasic(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, tc.expectedCount, result.RowCount) @@ -789,12 +784,13 @@ func TestAggregateBasic(t *testing.T) { } func TestAggregateGroup(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_group" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("sales") + collection := client.Database(dbName).Collection("sales") _, err := collection.InsertMany(ctx, []any{ bson.M{"item": "apple", "quantity": 10, "price": 1.5}, bson.M{"item": "banana", "quantity": 5, "price": 2.0}, @@ -855,7 +851,7 @@ func TestAggregateGroup(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, tc.expectedCount, result.RowCount) @@ -867,12 +863,13 @@ func TestAggregateGroup(t *testing.T) { } func TestAggregateCollectionAccess(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_coll_access" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("my-items") + collection := client.Database(dbName).Collection("my-items") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "test1"}, bson.M{"name": "test2"}, @@ -892,7 +889,7 @@ func TestAggregateCollectionAccess(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := gc.Execute(ctx, "testdb", tc.statement) + result, err := gc.Execute(ctx, dbName, tc.statement) require.NoError(t, err) require.NotNil(t, result) }) @@ -902,12 +899,13 @@ func TestAggregateCollectionAccess(t *testing.T) { // TestAggregateFilteredSubset tests the "Filtered Subset" example from MongoDB docs // https://www.mongodb.com/docs/manual/tutorial/aggregation-examples/filtered-subset/ func TestAggregateFilteredSubset(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_filtered" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("persons") + collection := client.Database(dbName).Collection("persons") _, err := collection.InsertMany(ctx, []any{ bson.M{ "person_id": "6392529400", @@ -971,7 +969,7 @@ func TestAggregateFilteredSubset(t *testing.T) { { $unset: ["_id", "vocation", "address"] } ])` - result, err := gc.Execute(ctx, "testdb", statement) + result, err := gc.Execute(ctx, dbName, statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 3, result.RowCount) @@ -992,12 +990,13 @@ func TestAggregateFilteredSubset(t *testing.T) { // TestAggregateGroupAndTotal tests the "Group and Total" example from MongoDB docs // https://www.mongodb.com/docs/manual/tutorial/aggregation-examples/group-and-total/ func TestAggregateGroupAndTotal(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_group_total" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("orders") + collection := client.Database(dbName).Collection("orders") _, err := collection.InsertMany(ctx, []any{ bson.M{ "customer_id": "elise_smith@myemail.com", @@ -1069,7 +1068,7 @@ func TestAggregateGroupAndTotal(t *testing.T) { { $unset: ["_id"] } ])` - result, err := gc.Execute(ctx, "testdb", statement) + result, err := gc.Execute(ctx, dbName, statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 3, result.RowCount) @@ -1087,12 +1086,13 @@ func TestAggregateGroupAndTotal(t *testing.T) { // TestAggregateUnwindArrays tests the "Unpack Arrays" example from MongoDB docs // https://www.mongodb.com/docs/manual/tutorial/aggregation-examples/unpack-arrays/ func TestAggregateUnwindArrays(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_unwind" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - collection := client.Database("testdb").Collection("orders") + collection := client.Database(dbName).Collection("orders") _, err := collection.InsertMany(ctx, []any{ bson.M{ "order_id": 6363763262239, @@ -1141,7 +1141,7 @@ func TestAggregateUnwindArrays(t *testing.T) { { $unset: ["_id"] } ])` - result, err := gc.Execute(ctx, "testdb", statement) + result, err := gc.Execute(ctx, dbName, statement) require.NoError(t, err) require.NotNil(t, result) // Should have: abc12345 (2x), def45678 (3x but all > 15), pqr88223 (1x), xyz11228 (1x) @@ -1157,13 +1157,14 @@ func TestAggregateUnwindArrays(t *testing.T) { // TestAggregateOneToOneJoin tests the "One-to-One Join" example from MongoDB docs // https://www.mongodb.com/docs/manual/tutorial/aggregation-examples/one-to-one-join/ func TestAggregateOneToOneJoin(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_join_1to1" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create products collection - productsCollection := client.Database("testdb").Collection("products") + productsCollection := client.Database(dbName).Collection("products") _, err := productsCollection.InsertMany(ctx, []any{ bson.M{ "id": "a1b2c3d4", @@ -1193,7 +1194,7 @@ func TestAggregateOneToOneJoin(t *testing.T) { require.NoError(t, err) // Create orders collection - ordersCollection := client.Database("testdb").Collection("orders") + ordersCollection := client.Database(dbName).Collection("orders") _, err = ordersCollection.InsertMany(ctx, []any{ bson.M{ "customer_id": "elise_smith@myemail.com", @@ -1246,7 +1247,7 @@ func TestAggregateOneToOneJoin(t *testing.T) { { $unset: ["_id", "product_id", "product_mapping"] } ])` - result, err := gc.Execute(ctx, "testdb", statement) + result, err := gc.Execute(ctx, dbName, statement) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 3, result.RowCount) // Only 2020 orders: elise, oranieri, jjones @@ -1261,13 +1262,14 @@ func TestAggregateOneToOneJoin(t *testing.T) { // TestAggregateMultiFieldJoin tests the "Multi-Field Join" example from MongoDB docs // https://www.mongodb.com/docs/manual/tutorial/aggregation-examples/multi-field-join/ func TestAggregateMultiFieldJoin(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_agg_join_multi" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create products collection - productsCollection := client.Database("testdb").Collection("products") + productsCollection := client.Database(dbName).Collection("products") _, err := productsCollection.InsertMany(ctx, []any{ bson.M{ "name": "Asus Laptop", @@ -1309,7 +1311,7 @@ func TestAggregateMultiFieldJoin(t *testing.T) { require.NoError(t, err) // Create orders collection - ordersCollection := client.Database("testdb").Collection("orders") + ordersCollection := client.Database(dbName).Collection("orders") _, err = ordersCollection.InsertMany(ctx, []any{ bson.M{ "customer_id": "elise_smith@myemail.com", @@ -1372,7 +1374,7 @@ func TestAggregateMultiFieldJoin(t *testing.T) { { $unset: ["_id"] } ])` - result, err := gc.Execute(ctx, "testdb", statement) + result, err := gc.Execute(ctx, dbName, statement) require.NoError(t, err) require.NotNil(t, result) // Should have: Asus Laptop Normal Display (2 orders), Morphy Richards (1 order) @@ -1385,155 +1387,91 @@ func TestAggregateMultiFieldJoin(t *testing.T) { require.NotContains(t, result.Rows[0], `"_id"`) } -func TestGetCollectionInfos(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestAggregateWithOptions(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_agg_options" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Create collections by inserting documents - _, err := client.Database("testdb").Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) - require.NoError(t, err) - _, err = client.Database("testdb").Collection("orders").InsertOne(ctx, bson.M{"item": "book"}) + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "age": 30}, + bson.M{"name": "Bob", "age": 25}, + }) require.NoError(t, err) gc := gomongo.NewClient(client) - // Test without filter - should return all collections - result, err := gc.Execute(ctx, "testdb", "db.getCollectionInfos()") + // aggregate with maxTimeMS option + result, err := gc.Execute(ctx, dbName, `db.users.aggregate([{ $match: { age: { $gt: 20 } } }], { maxTimeMS: 5000 })`) require.NoError(t, err) - require.NotNil(t, result) require.Equal(t, 2, result.RowCount) - - // Verify that results contain collection info structure - for _, row := range result.Rows { - require.Contains(t, row, `"name"`) - require.Contains(t, row, `"type"`) - } -} - -func TestGetCollectionInfosWithFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - // Create collections by inserting documents - _, err := client.Database("testdb").Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) - require.NoError(t, err) - _, err = client.Database("testdb").Collection("orders").InsertOne(ctx, bson.M{"item": "book"}) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // Test with filter - should return only matching collection - result, err := gc.Execute(ctx, "testdb", `db.getCollectionInfos({ name: "users" })`) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 1, result.RowCount) - - // Verify that the returned collection is "users" - require.Contains(t, result.Rows[0], `"name": "users"`) - require.Contains(t, result.Rows[0], `"type": "collection"`) } -func TestGetCollectionInfosEmptyResult(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestAggregateWithHintOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_agg_hint" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Create a collection - _, err := client.Database("testdb").Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // Test with filter that matches no collections - result, err := gc.Execute(ctx, "testdb", `db.getCollectionInfos({ name: "nonexistent" })`) + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "age": 30}, + bson.M{"name": "Bob", "age": 25}, + }) require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 0, result.RowCount) - require.Empty(t, result.Rows) -} - -func TestGetCollectionInfosNameOnly(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - // Create a collection - _, err := client.Database("testdb").Collection("users").InsertOne(ctx, bson.M{"name": "test"}) + // Create index on age field + _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ + Keys: bson.D{{Key: "age", Value: 1}}, + }) require.NoError(t, err) gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", `db.getCollectionInfos({}, { nameOnly: true })`) - require.NoError(t, err) - require.GreaterOrEqual(t, result.RowCount, 1) - - // With nameOnly: true, the result should contain "name" field - require.Contains(t, result.Rows[0], `"name"`) -} - -func TestGetCollectionInfosAuthorizedCollections(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - // Create a collection - _, err := client.Database("testdb").Collection("users").InsertOne(ctx, bson.M{"name": "test"}) + // aggregate with hint option (index name) + result, err := gc.Execute(ctx, dbName, `db.users.aggregate([{ $match: { age: { $gt: 20 } } }], { hint: "age_1" })`) require.NoError(t, err) + require.Equal(t, 2, result.RowCount) - gc := gomongo.NewClient(client) - - result, err := gc.Execute(ctx, "testdb", `db.getCollectionInfos({}, { authorizedCollections: true })`) + // aggregate with hint option (index spec) + result, err = gc.Execute(ctx, dbName, `db.users.aggregate([{ $match: { age: { $gt: 20 } } }], { hint: { age: 1 } })`) require.NoError(t, err) - require.GreaterOrEqual(t, result.RowCount, 1) + require.Equal(t, 2, result.RowCount) } -func TestGetCollectionInfosUnsupportedOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestAggregateTooManyArguments(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_agg_too_many_args" + defer testutil.CleanupDatabase(t, client, dbName) - gc := gomongo.NewClient(client) ctx := context.Background() - _, err := gc.Execute(ctx, "testdb", `db.getCollectionInfos({}, { unknownOption: true })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "getCollectionInfos()", optErr.Method) -} - -func TestGetCollectionInfosTooManyArgs(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - gc := gomongo.NewClient(client) - ctx := context.Background() - _, err := gc.Execute(ctx, "testdb", `db.getCollectionInfos({}, {}, {})`) + _, err := gc.Execute(ctx, dbName, `db.users.aggregate([], {}, "extra")`) require.Error(t, err) - require.Contains(t, err.Error(), "takes at most 2 arguments") + require.Contains(t, err.Error(), "aggregate() takes at most 2 arguments") } func TestGetIndexes(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_get_indexes" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with a document (this creates the default _id index) - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertOne(ctx, bson.M{"name": "alice", "email": "alice@example.com"}) require.NoError(t, err) gc := gomongo.NewClient(client) // Test getIndexes - should return at least the _id index - result, err := gc.Execute(ctx, "testdb", "db.users.getIndexes()") + result, err := gc.Execute(ctx, dbName, "db.users.getIndexes()") require.NoError(t, err) require.NotNil(t, result) require.GreaterOrEqual(t, result.RowCount, 1) @@ -1550,13 +1488,14 @@ func TestGetIndexes(t *testing.T) { } func TestGetIndexesWithCustomIndex(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_indexes_custom" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection and add a custom index - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertOne(ctx, bson.M{"name": "alice", "email": "alice@example.com"}) require.NoError(t, err) @@ -1568,7 +1507,7 @@ func TestGetIndexesWithCustomIndex(t *testing.T) { gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", "db.users.getIndexes()") + result, err := gc.Execute(ctx, dbName, "db.users.getIndexes()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 2, result.RowCount) // _id index + email index @@ -1589,20 +1528,21 @@ func TestGetIndexesWithCustomIndex(t *testing.T) { } func TestGetIndexesBracketNotation(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_indexes_bracket" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with hyphenated name - collection := client.Database("testdb").Collection("user-logs") + collection := client.Database(dbName).Collection("user-logs") _, err := collection.InsertOne(ctx, bson.M{"message": "test"}) require.NoError(t, err) gc := gomongo.NewClient(client) // Test with bracket notation - result, err := gc.Execute(ctx, "testdb", `db["user-logs"].getIndexes()`) + result, err := gc.Execute(ctx, dbName, `db["user-logs"].getIndexes()`) require.NoError(t, err) require.NotNil(t, result) require.GreaterOrEqual(t, result.RowCount, 1) @@ -1612,13 +1552,14 @@ func TestGetIndexesBracketNotation(t *testing.T) { } func TestCountDocuments(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_count" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30}, bson.M{"name": "bob", "age": 25}, @@ -1629,7 +1570,7 @@ func TestCountDocuments(t *testing.T) { gc := gomongo.NewClient(client) // Test countDocuments without filter - result, err := gc.Execute(ctx, "testdb", "db.users.countDocuments()") + result, err := gc.Execute(ctx, dbName, "db.users.countDocuments()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) @@ -1637,13 +1578,14 @@ func TestCountDocuments(t *testing.T) { } func TestCountDocumentsWithFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_count_filter" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30, "status": "active"}, bson.M{"name": "bob", "age": 25, "status": "inactive"}, @@ -1655,28 +1597,29 @@ func TestCountDocumentsWithFilter(t *testing.T) { gc := gomongo.NewClient(client) // Test countDocuments with filter - result, err := gc.Execute(ctx, "testdb", `db.users.countDocuments({ status: "active" })`) + result, err := gc.Execute(ctx, dbName, `db.users.countDocuments({ status: "active" })`) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) require.Equal(t, "3", result.Rows[0]) // Test with comparison operator - result, err = gc.Execute(ctx, "testdb", `db.users.countDocuments({ age: { $gte: 30 } })`) + result, err = gc.Execute(ctx, dbName, `db.users.countDocuments({ age: { $gte: 30 } })`) require.NoError(t, err) require.Equal(t, "2", result.Rows[0]) } func TestCountDocumentsEmptyCollection(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_count_empty" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) // Test countDocuments on empty/non-existent collection - result, err := gc.Execute(ctx, "testdb", "db.users.countDocuments()") + result, err := gc.Execute(ctx, dbName, "db.users.countDocuments()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) @@ -1684,13 +1627,14 @@ func TestCountDocumentsEmptyCollection(t *testing.T) { } func TestCountDocumentsWithEmptyFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_count_empty_filter" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("items") + collection := client.Database(dbName).Collection("items") _, err := collection.InsertMany(ctx, []any{ bson.M{"item": "a"}, bson.M{"item": "b"}, @@ -1700,20 +1644,21 @@ func TestCountDocumentsWithEmptyFilter(t *testing.T) { gc := gomongo.NewClient(client) // Test countDocuments with empty filter {} - result, err := gc.Execute(ctx, "testdb", "db.items.countDocuments({})") + result, err := gc.Execute(ctx, dbName, "db.items.countDocuments({})") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, "2", result.Rows[0]) } func TestCountDocumentsWithOptions(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_count_options" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "age": 30}, bson.M{"name": "bob", "age": 25}, @@ -1726,29 +1671,30 @@ func TestCountDocumentsWithOptions(t *testing.T) { gc := gomongo.NewClient(client) // Test with limit option - result, err := gc.Execute(ctx, "testdb", `db.users.countDocuments({}, { limit: 3 })`) + result, err := gc.Execute(ctx, dbName, `db.users.countDocuments({}, { limit: 3 })`) require.NoError(t, err) require.Equal(t, "3", result.Rows[0]) // Test with skip option - result, err = gc.Execute(ctx, "testdb", `db.users.countDocuments({}, { skip: 2 })`) + result, err = gc.Execute(ctx, dbName, `db.users.countDocuments({}, { skip: 2 })`) require.NoError(t, err) require.Equal(t, "3", result.Rows[0]) // Test with both limit and skip - result, err = gc.Execute(ctx, "testdb", `db.users.countDocuments({}, { skip: 1, limit: 2 })`) + result, err = gc.Execute(ctx, dbName, `db.users.countDocuments({}, { skip: 1, limit: 2 })`) require.NoError(t, err) require.Equal(t, "2", result.Rows[0]) } func TestCountDocumentsWithHint(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_count_hint" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents and an index - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "status": "active"}, bson.M{"name": "bob", "status": "inactive"}, @@ -1765,51 +1711,74 @@ func TestCountDocumentsWithHint(t *testing.T) { gc := gomongo.NewClient(client) // Test with hint using index name - result, err := gc.Execute(ctx, "testdb", `db.users.countDocuments({ status: "active" }, { hint: "status_1" })`) + result, err := gc.Execute(ctx, dbName, `db.users.countDocuments({ status: "active" }, { hint: "status_1" })`) require.NoError(t, err) require.Equal(t, "2", result.Rows[0]) // Test with hint using index specification document - result, err = gc.Execute(ctx, "testdb", `db.users.countDocuments({ status: "active" }, { hint: { status: 1 } })`) + result, err = gc.Execute(ctx, dbName, `db.users.countDocuments({ status: "active" }, { hint: { status: 1 } })`) require.NoError(t, err) require.Equal(t, "2", result.Rows[0]) } -func TestEstimatedDocumentCount(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestCountDocumentsMaxTimeMS(t *testing.T) { + dbName := "testdb_count_maxtime" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - // Create a collection with documents - collection := client.Database("testdb").Collection("users") - _, err := collection.InsertMany(ctx, []any{ - bson.M{"name": "alice"}, - bson.M{"name": "bob"}, - bson.M{"name": "charlie"}, + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice"}, + bson.M{"name": "Bob"}, }) require.NoError(t, err) gc := gomongo.NewClient(client) - // Test estimatedDocumentCount - result, err := gc.Execute(ctx, "testdb", "db.users.estimatedDocumentCount()") + result, err := gc.Execute(ctx, dbName, `db.users.countDocuments({}, { maxTimeMS: 5000 })`) require.NoError(t, err) - require.NotNil(t, result) + require.Equal(t, "2", result.Rows[0]) +} + +func TestEstimatedDocumentCount(t *testing.T) { + dbName := "testdb_est_count" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create a collection with documents + collection := client.Database(dbName).Collection("users") + _, err := collection.InsertMany(ctx, []any{ + bson.M{"name": "alice"}, + bson.M{"name": "bob"}, + bson.M{"name": "charlie"}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Test estimatedDocumentCount + result, err := gc.Execute(ctx, dbName, "db.users.estimatedDocumentCount()") + require.NoError(t, err) + require.NotNil(t, result) require.Equal(t, 1, result.RowCount) require.Equal(t, "3", result.Rows[0]) } func TestEstimatedDocumentCountEmptyCollection(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_est_count_empty" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) // Test estimatedDocumentCount on empty/non-existent collection - result, err := gc.Execute(ctx, "testdb", "db.users.estimatedDocumentCount()") + result, err := gc.Execute(ctx, dbName, "db.users.estimatedDocumentCount()") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 1, result.RowCount) @@ -1817,13 +1786,14 @@ func TestEstimatedDocumentCountEmptyCollection(t *testing.T) { } func TestEstimatedDocumentCountWithEmptyOptions(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_est_count_opts" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("items") + collection := client.Database(dbName).Collection("items") _, err := collection.InsertMany(ctx, []any{ bson.M{"item": "a"}, bson.M{"item": "b"}, @@ -1833,20 +1803,43 @@ func TestEstimatedDocumentCountWithEmptyOptions(t *testing.T) { gc := gomongo.NewClient(client) // Test estimatedDocumentCount with empty options {} - result, err := gc.Execute(ctx, "testdb", "db.items.estimatedDocumentCount({})") + result, err := gc.Execute(ctx, dbName, "db.items.estimatedDocumentCount({})") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, "2", result.Rows[0]) } +func TestEstimatedDocumentCountMaxTimeMS(t *testing.T) { + dbName := "testdb_est_count_maxtime" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + coll := client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice"}, + bson.M{"name": "Bob"}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + result, err := gc.Execute(ctx, dbName, `db.users.estimatedDocumentCount({ maxTimeMS: 5000 })`) + require.NoError(t, err) + require.Equal(t, 1, result.RowCount) + require.Equal(t, "2", result.Rows[0]) +} + func TestDistinct(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_distinct" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("users") + collection := client.Database(dbName).Collection("users") _, err := collection.InsertMany(ctx, []any{ bson.M{"name": "alice", "status": "active"}, bson.M{"name": "bob", "status": "inactive"}, @@ -1858,7 +1851,7 @@ func TestDistinct(t *testing.T) { gc := gomongo.NewClient(client) // Test distinct on status field - result, err := gc.Execute(ctx, "testdb", `db.users.distinct("status")`) + result, err := gc.Execute(ctx, dbName, `db.users.distinct("status")`) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 2, result.RowCount) @@ -1872,13 +1865,14 @@ func TestDistinct(t *testing.T) { } func TestDistinctWithFilter(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_distinct_filter" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with documents - collection := client.Database("testdb").Collection("products") + collection := client.Database(dbName).Collection("products") _, err := collection.InsertMany(ctx, []any{ bson.M{"category": "electronics", "brand": "Apple", "price": 999}, bson.M{"category": "electronics", "brand": "Samsung", "price": 799}, @@ -1891,7 +1885,7 @@ func TestDistinctWithFilter(t *testing.T) { gc := gomongo.NewClient(client) // Test distinct with filter - result, err := gc.Execute(ctx, "testdb", `db.products.distinct("brand", { category: "electronics" })`) + result, err := gc.Execute(ctx, dbName, `db.products.distinct("brand", { category: "electronics" })`) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 2, result.RowCount) @@ -1908,15 +1902,16 @@ func TestDistinctWithFilter(t *testing.T) { } func TestDistinctEmptyCollection(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_distinct_empty" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) // Test distinct on empty/non-existent collection - result, err := gc.Execute(ctx, "testdb", `db.users.distinct("status")`) + result, err := gc.Execute(ctx, dbName, `db.users.distinct("status")`) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 0, result.RowCount) @@ -1924,13 +1919,14 @@ func TestDistinctEmptyCollection(t *testing.T) { } func TestDistinctBracketNotation(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_distinct_bracket" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with hyphenated name - collection := client.Database("testdb").Collection("user-logs") + collection := client.Database(dbName).Collection("user-logs") _, err := collection.InsertMany(ctx, []any{ bson.M{"level": "info"}, bson.M{"level": "warn"}, @@ -1942,20 +1938,21 @@ func TestDistinctBracketNotation(t *testing.T) { gc := gomongo.NewClient(client) // Test with bracket notation - result, err := gc.Execute(ctx, "testdb", `db["user-logs"].distinct("level")`) + result, err := gc.Execute(ctx, dbName, `db["user-logs"].distinct("level")`) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 3, result.RowCount) } func TestDistinctNumericValues(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_distinct_numeric" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() // Create a collection with numeric values - collection := client.Database("testdb").Collection("scores") + collection := client.Database(dbName).Collection("scores") _, err := collection.InsertMany(ctx, []any{ bson.M{"score": 100}, bson.M{"score": 85}, @@ -1968,361 +1965,20 @@ func TestDistinctNumericValues(t *testing.T) { gc := gomongo.NewClient(client) // Test distinct on numeric field - result, err := gc.Execute(ctx, "testdb", `db.scores.distinct("score")`) + result, err := gc.Execute(ctx, dbName, `db.scores.distinct("score")`) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 3, result.RowCount) // 100, 85, 90 } -func TestCursorCountUnsupported(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) - - // cursor.count() is not in the planned registry, should return UnsupportedOperationError - _, err := gc.Execute(ctx, "testdb", "db.users.find().count()") - require.Error(t, err) - - var unsupportedErr *gomongo.UnsupportedOperationError - require.ErrorAs(t, err, &unsupportedErr) - require.Equal(t, "count()", unsupportedErr.Operation) -} - -func TestUnsupportedOptionError(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) - - // find() with unsupported option 'collation' - _, err := gc.Execute(ctx, "testdb", `db.users.find({}, {}, { collation: { locale: "en" } })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "find()", optErr.Method) - require.Equal(t, "collation", optErr.Option) -} - -func TestFindWithProjectionArg(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - // Insert test data - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice", "age": 30, "city": "NYC"}, - bson.M{"name": "Bob", "age": 25, "city": "LA"}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // find with projection as 2nd argument - result, err := gc.Execute(ctx, "testdb", `db.users.find({}, { name: 1, _id: 0 })`) - require.NoError(t, err) - require.Equal(t, 2, result.RowCount) - - // Verify only 'name' field is returned - for _, row := range result.Rows { - require.Contains(t, row, "name") - require.NotContains(t, row, "age") - require.NotContains(t, row, "city") - } -} - -func TestFindWithHintOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice", "age": 30}, - bson.M{"name": "Bob", "age": 25}, - }) - require.NoError(t, err) - - // Create index - _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "name", Value: 1}}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // find with hint option (index name) - result, err := gc.Execute(ctx, "testdb", `db.users.find({}, {}, { hint: "name_1" })`) - require.NoError(t, err) - require.Equal(t, 2, result.RowCount) -} - -func TestFindWithMaxMinOptions(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("items") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"price": 10}, - bson.M{"price": 20}, - bson.M{"price": 30}, - bson.M{"price": 40}, - bson.M{"price": 50}, - }) - require.NoError(t, err) - - // Create index on price field (required for min/max) - _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "price", Value: 1}}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // find with min and max options (requires hint) - result, err := gc.Execute(ctx, "testdb", `db.items.find({}, {}, { hint: { price: 1 }, min: { price: 20 }, max: { price: 40 } })`) - require.NoError(t, err) - // Should return items with price 20 and 30 (max is exclusive) - require.Equal(t, 2, result.RowCount) -} - -func TestFindWithMaxTimeMSOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice"}, - bson.M{"name": "Bob"}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // find with maxTimeMS option - result, err := gc.Execute(ctx, "testdb", `db.users.find({}, {}, { maxTimeMS: 5000 })`) - require.NoError(t, err) - require.Equal(t, 2, result.RowCount) -} - -func TestFindOneWithProjectionAndOptions(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice", "age": 30, "city": "NYC"}, - bson.M{"name": "Bob", "age": 25, "city": "LA"}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // findOne with projection as 2nd argument - result, err := gc.Execute(ctx, "testdb", `db.users.findOne({}, { name: 1, _id: 0 })`) - require.NoError(t, err) - require.Equal(t, 1, result.RowCount) - require.Contains(t, result.Rows[0], "name") - require.NotContains(t, result.Rows[0], "age") -} - -func TestFindOneUnsupportedOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) - - _, err := gc.Execute(ctx, "testdb", `db.users.findOne({}, {}, { collation: { locale: "en" } })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "findOne()", optErr.Method) - require.Equal(t, "collation", optErr.Option) -} - -func TestFindOneWithHintOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice", "age": 30}, - bson.M{"name": "Bob", "age": 25}, - }) - require.NoError(t, err) - - // Create index - _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "name", Value: 1}}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // findOne with hint option (index name) - result, err := gc.Execute(ctx, "testdb", `db.users.findOne({}, {}, { hint: "name_1" })`) - require.NoError(t, err) - require.Equal(t, 1, result.RowCount) -} - -func TestFindOneWithMaxTimeMSOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice"}, - bson.M{"name": "Bob"}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // findOne with maxTimeMS option - result, err := gc.Execute(ctx, "testdb", `db.users.findOne({}, {}, { maxTimeMS: 5000 })`) - require.NoError(t, err) - require.Equal(t, 1, result.RowCount) -} - -func TestAggregateWithOptions(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice", "age": 30}, - bson.M{"name": "Bob", "age": 25}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // aggregate with maxTimeMS option - result, err := gc.Execute(ctx, "testdb", `db.users.aggregate([{ $match: { age: { $gt: 20 } } }], { maxTimeMS: 5000 })`) - require.NoError(t, err) - require.Equal(t, 2, result.RowCount) -} - -func TestAggregateWithHintOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice", "age": 30}, - bson.M{"name": "Bob", "age": 25}, - }) - require.NoError(t, err) - - // Create index on age field - _, err = coll.Indexes().CreateOne(ctx, mongo.IndexModel{ - Keys: bson.D{{Key: "age", Value: 1}}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - // aggregate with hint option (index name) - result, err := gc.Execute(ctx, "testdb", `db.users.aggregate([{ $match: { age: { $gt: 20 } } }], { hint: "age_1" })`) - require.NoError(t, err) - require.Equal(t, 2, result.RowCount) - - // aggregate with hint option (index spec) - result, err = gc.Execute(ctx, "testdb", `db.users.aggregate([{ $match: { age: { $gt: 20 } } }], { hint: { age: 1 } })`) - require.NoError(t, err) - require.Equal(t, 2, result.RowCount) -} - -func TestAggregateUnsupportedOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) - - _, err := gc.Execute(ctx, "testdb", `db.users.aggregate([], { allowDiskUse: true })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "aggregate()", optErr.Method) - require.Equal(t, "allowDiskUse", optErr.Option) -} - -func TestAggregateTooManyArguments(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) - - _, err := gc.Execute(ctx, "testdb", `db.users.aggregate([], {}, "extra")`) - require.Error(t, err) - require.Contains(t, err.Error(), "aggregate() takes at most 2 arguments") -} - -func TestCountDocumentsMaxTimeMS(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice"}, - bson.M{"name": "Bob"}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - result, err := gc.Execute(ctx, "testdb", `db.users.countDocuments({}, { maxTimeMS: 5000 })`) - require.NoError(t, err) - require.Equal(t, "2", result.Rows[0]) -} - -func TestCountDocumentsUnsupportedOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) - - _, err := gc.Execute(ctx, "testdb", `db.users.countDocuments({}, { collation: { locale: "en" } })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "countDocuments()", optErr.Method) -} - func TestDistinctMaxTimeMS(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + dbName := "testdb_distinct_maxtime" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() - coll := client.Database("testdb").Collection("users") + coll := client.Database(dbName).Collection("users") _, err := coll.InsertMany(ctx, []any{ bson.M{"name": "Alice", "city": "NYC"}, bson.M{"name": "Bob", "city": "LA"}, @@ -2332,70 +1988,39 @@ func TestDistinctMaxTimeMS(t *testing.T) { gc := gomongo.NewClient(client) - result, err := gc.Execute(ctx, "testdb", `db.users.distinct("city", {}, { maxTimeMS: 5000 })`) + result, err := gc.Execute(ctx, dbName, `db.users.distinct("city", {}, { maxTimeMS: 5000 })`) require.NoError(t, err) require.Equal(t, 2, result.RowCount) } -func TestDistinctUnsupportedOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() +func TestCursorCountUnsupported(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_cursor_count" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) - _, err := gc.Execute(ctx, "testdb", `db.users.distinct("city", {}, { collation: { locale: "en" } })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "distinct()", optErr.Method) -} - -func TestEstimatedDocumentCountMaxTimeMS(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - coll := client.Database("testdb").Collection("users") - _, err := coll.InsertMany(ctx, []any{ - bson.M{"name": "Alice"}, - bson.M{"name": "Bob"}, - }) - require.NoError(t, err) - - gc := gomongo.NewClient(client) - - result, err := gc.Execute(ctx, "testdb", `db.users.estimatedDocumentCount({ maxTimeMS: 5000 })`) - require.NoError(t, err) - require.Equal(t, 1, result.RowCount) - require.Equal(t, "2", result.Rows[0]) -} - -func TestEstimatedDocumentCountUnsupportedOption(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() - - ctx := context.Background() - - gc := gomongo.NewClient(client) + // cursor.count() is not in the planned registry, should return UnsupportedOperationError + _, err := gc.Execute(ctx, dbName, "db.users.find().count()") + require.Error(t, err) - _, err := gc.Execute(ctx, "testdb", `db.users.estimatedDocumentCount({ comment: "test" })`) - var optErr *gomongo.UnsupportedOptionError - require.ErrorAs(t, err, &optErr) - require.Equal(t, "estimatedDocumentCount()", optErr.Method) - require.Equal(t, "comment", optErr.Option) + var unsupportedErr *gomongo.UnsupportedOperationError + require.ErrorAs(t, err, &unsupportedErr) + require.Equal(t, "count()", unsupportedErr.Operation) } func TestCursorHintMethod(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_cursor_hint" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) - coll := client.Database("testdb").Collection("users") + coll := client.Database(dbName).Collection("users") _, err := coll.InsertMany(ctx, []any{ bson.M{"name": "Alice", "age": 30}, bson.M{"name": "Bob", "age": 25}, @@ -2409,20 +2034,21 @@ func TestCursorHintMethod(t *testing.T) { require.NoError(t, err) // Use hint() cursor method with string - result, err := gc.Execute(ctx, "testdb", `db.users.find({}).hint("name_1")`) + result, err := gc.Execute(ctx, dbName, `db.users.find({}).hint("name_1")`) require.NoError(t, err) require.Equal(t, 2, result.RowCount) } func TestCursorHintMethodWithDocument(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_cursor_hint_doc" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) - coll := client.Database("testdb").Collection("users") + coll := client.Database(dbName).Collection("users") _, err := coll.InsertMany(ctx, []any{ bson.M{"name": "Alice", "age": 30}, }) @@ -2434,20 +2060,21 @@ func TestCursorHintMethodWithDocument(t *testing.T) { require.NoError(t, err) // Use hint() cursor method with document - result, err := gc.Execute(ctx, "testdb", `db.users.find({}).hint({ name: 1 })`) + result, err := gc.Execute(ctx, dbName, `db.users.find({}).hint({ name: 1 })`) require.NoError(t, err) require.Equal(t, 1, result.RowCount) } func TestCursorMaxMethod(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_cursor_max" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) - coll := client.Database("testdb").Collection("users") + coll := client.Database(dbName).Collection("users") _, err := coll.InsertMany(ctx, []any{ bson.M{"name": "Alice", "age": 30}, bson.M{"name": "Bob", "age": 25}, @@ -2462,21 +2089,22 @@ func TestCursorMaxMethod(t *testing.T) { require.NoError(t, err) // Use max() cursor method - returns documents with age < 30 - result, err := gc.Execute(ctx, "testdb", `db.users.find({}).hint({ age: 1 }).max({ age: 30 })`) + result, err := gc.Execute(ctx, dbName, `db.users.find({}).hint({ age: 1 }).max({ age: 30 })`) require.NoError(t, err) require.Equal(t, 1, result.RowCount) require.Contains(t, result.Rows[0], `"Bob"`) } func TestCursorMinMethod(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_cursor_min" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) - coll := client.Database("testdb").Collection("users") + coll := client.Database(dbName).Collection("users") _, err := coll.InsertMany(ctx, []any{ bson.M{"name": "Alice", "age": 30}, bson.M{"name": "Bob", "age": 25}, @@ -2491,20 +2119,21 @@ func TestCursorMinMethod(t *testing.T) { require.NoError(t, err) // Use min() cursor method - returns documents with age >= 30 - result, err := gc.Execute(ctx, "testdb", `db.users.find({}).hint({ age: 1 }).min({ age: 30 })`) + result, err := gc.Execute(ctx, dbName, `db.users.find({}).hint({ age: 1 }).min({ age: 30 })`) require.NoError(t, err) require.Equal(t, 2, result.RowCount) } func TestCursorMinMaxCombined(t *testing.T) { - client, cleanup := setupTestContainer(t) - defer cleanup() + client := testutil.GetClient(t) + dbName := "testdb_cursor_minmax" + defer testutil.CleanupDatabase(t, client, dbName) ctx := context.Background() gc := gomongo.NewClient(client) - coll := client.Database("testdb").Collection("users") + coll := client.Database(dbName).Collection("users") _, err := coll.InsertMany(ctx, []any{ bson.M{"name": "Alice", "age": 30}, bson.M{"name": "Bob", "age": 25}, @@ -2520,7 +2149,7 @@ func TestCursorMinMaxCombined(t *testing.T) { require.NoError(t, err) // Use min() and max() together - returns documents with 30 <= age < 40 - result, err := gc.Execute(ctx, "testdb", `db.users.find({}).hint({ age: 1 }).min({ age: 30 }).max({ age: 40 })`) + result, err := gc.Execute(ctx, dbName, `db.users.find({}).hint({ age: 1 }).min({ age: 30 }).max({ age: 40 })`) require.NoError(t, err) require.Equal(t, 2, result.RowCount) } diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..c868d0e --- /dev/null +++ b/database_test.go @@ -0,0 +1,278 @@ +package gomongo_test + +import ( + "context" + "slices" + "testing" + + "github.com/bytebase/gomongo" + "github.com/bytebase/gomongo/internal/testutil" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestShowDatabases(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_show_dbs" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create a database by inserting a document + _, err := client.Database(dbName).Collection("test").InsertOne(ctx, bson.M{"x": 1}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + tests := []struct { + name string + statement string + }{ + {"show dbs", "show dbs"}, + {"show databases", "show databases"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := gc.Execute(ctx, dbName, tc.statement) + require.NoError(t, err) + require.NotNil(t, result) + require.GreaterOrEqual(t, result.RowCount, 1) + + // Check that dbName is in the result + require.True(t, slices.Contains(result.Rows, dbName), "expected '%s' in database list, got: %v", dbName, result.Rows) + }) + } +} + +func TestShowCollections(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_show_colls" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create collections by inserting documents + _, err := client.Database(dbName).Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) + require.NoError(t, err) + _, err = client.Database(dbName).Collection("orders").InsertOne(ctx, bson.M{"item": "book"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + result, err := gc.Execute(ctx, dbName, "show collections") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 2, result.RowCount) + + // Check that both collections are in the result + collectionSet := make(map[string]bool) + for _, row := range result.Rows { + collectionSet[row] = true + } + require.True(t, collectionSet["users"], "expected 'users' collection") + require.True(t, collectionSet["orders"], "expected 'orders' collection") +} + +func TestGetCollectionNames(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_get_coll_names" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create collections by inserting documents + _, err := client.Database(dbName).Collection("products").InsertOne(ctx, bson.M{"name": "widget"}) + require.NoError(t, err) + _, err = client.Database(dbName).Collection("categories").InsertOne(ctx, bson.M{"name": "electronics"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + result, err := gc.Execute(ctx, dbName, "db.getCollectionNames()") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 2, result.RowCount) + + // Check that both collections are in the result + collectionSet := make(map[string]bool) + for _, row := range result.Rows { + collectionSet[row] = true + } + require.True(t, collectionSet["products"], "expected 'products' collection") + require.True(t, collectionSet["categories"], "expected 'categories' collection") +} + +func TestGetCollectionInfos(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_get_coll_infos" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create collections by inserting documents + _, err := client.Database(dbName).Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) + require.NoError(t, err) + _, err = client.Database(dbName).Collection("orders").InsertOne(ctx, bson.M{"item": "book"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Test without filter - should return all collections + result, err := gc.Execute(ctx, dbName, "db.getCollectionInfos()") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 2, result.RowCount) + + // Verify that results contain collection info structure + for _, row := range result.Rows { + require.Contains(t, row, `"name"`) + require.Contains(t, row, `"type"`) + } +} + +func TestGetCollectionInfosWithFilter(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_infos_filter" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create collections by inserting documents + _, err := client.Database(dbName).Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) + require.NoError(t, err) + _, err = client.Database(dbName).Collection("orders").InsertOne(ctx, bson.M{"item": "book"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Test with filter - should return only matching collection + result, err := gc.Execute(ctx, dbName, `db.getCollectionInfos({ name: "users" })`) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.RowCount) + + // Verify that the returned collection is "users" + require.Contains(t, result.Rows[0], `"name": "users"`) + require.Contains(t, result.Rows[0], `"type": "collection"`) +} + +func TestGetCollectionInfosEmptyResult(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_infos_empty" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create a collection + _, err := client.Database(dbName).Collection("users").InsertOne(ctx, bson.M{"name": "alice"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + // Test with filter that matches no collections + result, err := gc.Execute(ctx, dbName, `db.getCollectionInfos({ name: "nonexistent" })`) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 0, result.RowCount) + require.Empty(t, result.Rows) +} + +func TestGetCollectionInfosNameOnly(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_infos_nameonly" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create a collection + _, err := client.Database(dbName).Collection("users").InsertOne(ctx, bson.M{"name": "test"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + result, err := gc.Execute(ctx, dbName, `db.getCollectionInfos({}, { nameOnly: true })`) + require.NoError(t, err) + require.GreaterOrEqual(t, result.RowCount, 1) + + // With nameOnly: true, the result should contain "name" field + require.Contains(t, result.Rows[0], `"name"`) +} + +func TestGetCollectionInfosAuthorizedCollections(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_infos_auth" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Create a collection + _, err := client.Database(dbName).Collection("users").InsertOne(ctx, bson.M{"name": "test"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + result, err := gc.Execute(ctx, dbName, `db.getCollectionInfos({}, { authorizedCollections: true })`) + require.NoError(t, err) + require.GreaterOrEqual(t, result.RowCount, 1) +} + +func TestGetCollectionInfosUnsupportedOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_infos_unsup" + defer testutil.CleanupDatabase(t, client, dbName) + + gc := gomongo.NewClient(client) + ctx := context.Background() + + _, err := gc.Execute(ctx, dbName, `db.getCollectionInfos({}, { unknownOption: true })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "getCollectionInfos()", optErr.Method) +} + +func TestGetCollectionInfosTooManyArgs(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_infos_args" + defer testutil.CleanupDatabase(t, client, dbName) + + gc := gomongo.NewClient(client) + ctx := context.Background() + + _, err := gc.Execute(ctx, dbName, `db.getCollectionInfos({}, {}, {})`) + require.Error(t, err) + require.Contains(t, err.Error(), "takes at most 2 arguments") +} + +func TestCollectionAccessPatterns(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_coll_access" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + // Insert a document + collection := client.Database(dbName).Collection("my-collection") + _, err := collection.InsertOne(ctx, bson.M{"data": "test"}) + require.NoError(t, err) + + gc := gomongo.NewClient(client) + + tests := []struct { + name string + statement string + }{ + {"dot access", "db.users.find()"}, + {"bracket double quote", `db["my-collection"].find()`}, + {"bracket single quote", `db['my-collection'].find()`}, + {"getCollection", `db.getCollection("my-collection").find()`}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := gc.Execute(ctx, dbName, tc.statement) + require.NoError(t, err) + require.NotNil(t, result) + }) + } +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..1733678 --- /dev/null +++ b/error_test.go @@ -0,0 +1,164 @@ +package gomongo_test + +import ( + "context" + "testing" + + "github.com/bytebase/gomongo" + "github.com/bytebase/gomongo/internal/testutil" + "github.com/stretchr/testify/require" +) + +func TestParseError(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_parse_error" + defer testutil.CleanupDatabase(t, client, dbName) + + gc := gomongo.NewClient(client) + ctx := context.Background() + + _, err := gc.Execute(ctx, dbName, "db.users.find({ name: })") + require.Error(t, err) + + var parseErr *gomongo.ParseError + require.ErrorAs(t, err, &parseErr) +} + +func TestPlannedOperation(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_planned_op" + defer testutil.CleanupDatabase(t, client, dbName) + + gc := gomongo.NewClient(client) + ctx := context.Background() + + // insertOne is a planned M2 operation - should return PlannedOperationError + _, err := gc.Execute(ctx, dbName, "db.users.insertOne({ name: 'test' })") + require.Error(t, err) + + var plannedErr *gomongo.PlannedOperationError + require.ErrorAs(t, err, &plannedErr) + require.Equal(t, "insertOne()", plannedErr.Operation) +} + +func TestUnsupportedOperation(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_unsup_op" + defer testutil.CleanupDatabase(t, client, dbName) + + gc := gomongo.NewClient(client) + ctx := context.Background() + + // createSearchIndex is NOT in the registry - should return UnsupportedOperationError + _, err := gc.Execute(ctx, dbName, `db.movies.createSearchIndex({ name: "default", definition: { mappings: { dynamic: true } } })`) + require.Error(t, err) + + var unsupportedErr *gomongo.UnsupportedOperationError + require.ErrorAs(t, err, &unsupportedErr) + require.Equal(t, "createSearchIndex()", unsupportedErr.Operation) +} + +func TestUnsupportedOptionError(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_unsup_opt_err" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(client) + + // find() with unsupported option 'collation' + _, err := gc.Execute(ctx, dbName, `db.users.find({}, {}, { collation: { locale: "en" } })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "find()", optErr.Method) + require.Equal(t, "collation", optErr.Option) +} + +func TestMethodRegistryStats(t *testing.T) { + total := gomongo.MethodRegistryStats() + + // Registry should contain M2 (10) + M3 (22) = 32 planned methods + require.Equal(t, 32, total, "expected 32 planned methods in registry (M2: 10, M3: 22)") + + // Log stats for visibility + t.Logf("Method Registry Stats: total=%d planned methods", total) +} + +func TestFindOneUnsupportedOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_findone_unsup_opt" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(client) + + _, err := gc.Execute(ctx, dbName, `db.users.findOne({}, {}, { collation: { locale: "en" } })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "findOne()", optErr.Method) + require.Equal(t, "collation", optErr.Option) +} + +func TestAggregateUnsupportedOption(t *testing.T) { + client := testutil.GetClient(t) + dbName := "testdb_agg_unsup_opt" + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(client) + + _, err := gc.Execute(ctx, dbName, `db.users.aggregate([], { allowDiskUse: true })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "aggregate()", optErr.Method) + require.Equal(t, "allowDiskUse", optErr.Option) +} + +func TestCountDocumentsUnsupportedOption(t *testing.T) { + dbName := "testdb_count_unsup" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(client) + + _, err := gc.Execute(ctx, dbName, `db.users.countDocuments({}, { collation: { locale: "en" } })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "countDocuments()", optErr.Method) +} + +func TestDistinctUnsupportedOption(t *testing.T) { + dbName := "testdb_distinct_unsup" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(client) + + _, err := gc.Execute(ctx, dbName, `db.users.distinct("city", {}, { collation: { locale: "en" } })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "distinct()", optErr.Method) +} + +func TestEstimatedDocumentCountUnsupportedOption(t *testing.T) { + dbName := "testdb_est_count_unsup" + client := testutil.GetClient(t) + defer testutil.CleanupDatabase(t, client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(client) + + _, err := gc.Execute(ctx, dbName, `db.users.estimatedDocumentCount({ comment: "test" })`) + var optErr *gomongo.UnsupportedOptionError + require.ErrorAs(t, err, &optErr) + require.Equal(t, "estimatedDocumentCount()", optErr.Method) + require.Equal(t, "comment", optErr.Option) +} diff --git a/errors.go b/errors.go index 29d4117..37f7f7a 100644 --- a/errors.go +++ b/errors.go @@ -1,6 +1,10 @@ package gomongo -import "fmt" +import ( + "fmt" + + "github.com/bytebase/gomongo/internal/translator" +) // ParseError represents a syntax error during parsing. type ParseError struct { @@ -47,3 +51,13 @@ type UnsupportedOptionError struct { func (e *UnsupportedOptionError) Error() string { return fmt.Sprintf("unsupported option '%s' in %s", e.Option, e.Method) } + +// MethodRegistryStats returns statistics about the method registry. +func MethodRegistryStats() int { + return translator.MethodRegistryStats() +} + +// IsPlannedMethod checks if a method is planned for future implementation. +func IsPlannedMethod(context, methodName string) bool { + return translator.IsPlannedMethod(context, methodName) +} diff --git a/executor.go b/executor.go index 2ff7f0b..4751014 100644 --- a/executor.go +++ b/executor.go @@ -2,504 +2,45 @@ package gomongo import ( "context" - "encoding/json" - "fmt" - "time" - "github.com/antlr4-go/antlr/v4" - "github.com/bytebase/parser/mongodb" - "go.mongodb.org/mongo-driver/v2/bson" + "github.com/bytebase/gomongo/internal/executor" + "github.com/bytebase/gomongo/internal/translator" "go.mongodb.org/mongo-driver/v2/mongo" - "go.mongodb.org/mongo-driver/v2/mongo/options" ) // execute parses and executes a MongoDB shell statement. func execute(ctx context.Context, client *mongo.Client, database, statement string) (*Result, error) { - // Parse the statement - tree, parseErrors := parseMongoShell(statement) - if len(parseErrors) > 0 { - err := parseErrors[0] - return nil, &ParseError{ - Line: err.Line, - Column: err.Column, - Message: err.Message, - } - } - - // Extract operation from parse tree - visitor := newMongoShellVisitor() - visitor.Visit(tree) - if visitor.err != nil { - return nil, visitor.err - } - - // Execute operation - return executeOperation(ctx, client, database, visitor.operation, statement) -} - -// parseMongoShell parses a MongoDB shell statement and returns the parse tree. -func parseMongoShell(statement string) (mongodb.IProgramContext, []*mongodb.MongoShellParseError) { - is := antlr.NewInputStream(statement) - lexer := mongodb.NewMongoShellLexer(is) - - errorListener := mongodb.NewMongoShellErrorListener() - lexer.RemoveErrorListeners() - lexer.AddErrorListener(errorListener) - - stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) - parser := mongodb.NewMongoShellParser(stream) - - parser.RemoveErrorListeners() - parser.AddErrorListener(errorListener) - - parser.BuildParseTrees = true - tree := parser.Program() - - return tree, errorListener.Errors -} - -// executeOperation executes a parsed MongoDB operation. -func executeOperation(ctx context.Context, client *mongo.Client, database string, op *mongoOperation, statement string) (*Result, error) { - switch op.opType { - case opFind: - return executeFind(ctx, client, database, op) - case opFindOne: - return executeFindOne(ctx, client, database, op) - case opAggregate: - return executeAggregate(ctx, client, database, op) - case opShowDatabases: - return executeShowDatabases(ctx, client) - case opShowCollections: - return executeShowCollections(ctx, client, database) - case opGetCollectionNames: - return executeGetCollectionNames(ctx, client, database) - case opGetCollectionInfos: - return executeGetCollectionInfos(ctx, client, database, op) - case opGetIndexes: - return executeGetIndexes(ctx, client, database, op) - case opCountDocuments: - return executeCountDocuments(ctx, client, database, op) - case opEstimatedDocumentCount: - return executeEstimatedDocumentCount(ctx, client, database, op) - case opDistinct: - return executeDistinct(ctx, client, database, op) - default: - return nil, &UnsupportedOperationError{ - Operation: statement, - } - } -} - -// executeFind executes a find operation. -func executeFind(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - filter := op.filter - if filter == nil { - filter = bson.D{} - } - - opts := options.Find() - if op.sort != nil { - opts.SetSort(op.sort) - } - if op.limit != nil { - opts.SetLimit(*op.limit) - } - if op.skip != nil { - opts.SetSkip(*op.skip) - } - if op.projection != nil { - opts.SetProjection(op.projection) - } - if op.hint != nil { - opts.SetHint(op.hint) - } - if op.max != nil { - opts.SetMax(op.max) - } - if op.min != nil { - opts.SetMin(op.min) - } - - // Apply maxTimeMS using context timeout. - // Note: MongoDB Go driver v2 removed SetMaxTime() from options. The recommended - // replacement is context.WithTimeout(). This is a client-side timeout (includes - // network latency), unlike mongosh's maxTimeMS which is server-side only. - if op.maxTimeMS != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.maxTimeMS)*time.Millisecond) - defer cancel() - } - - cursor, err := collection.Find(ctx, filter, opts) - if err != nil { - return nil, fmt.Errorf("find failed: %w", err) - } - defer func() { _ = cursor.Close(ctx) }() - - var rows []string - for cursor.Next(ctx) { - var doc bson.M - if err := cursor.Decode(&doc); err != nil { - return nil, fmt.Errorf("decode failed: %w", err) - } - - // Marshal to Extended JSON (Relaxed) - jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") - if err != nil { - return nil, fmt.Errorf("marshal failed: %w", err) - } - rows = append(rows, string(jsonBytes)) - } - - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %w", err) - } - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// executeAggregate executes an aggregation pipeline. -func executeAggregate(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - pipeline := op.pipeline - if pipeline == nil { - pipeline = bson.A{} - } - - opts := options.Aggregate() - if op.hint != nil { - opts.SetHint(op.hint) - } - - // Apply maxTimeMS using context timeout (see comment in executeFind for details). - if op.maxTimeMS != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.maxTimeMS)*time.Millisecond) - defer cancel() - } - - cursor, err := collection.Aggregate(ctx, pipeline, opts) - if err != nil { - return nil, fmt.Errorf("aggregate failed: %w", err) - } - defer func() { _ = cursor.Close(ctx) }() - - var rows []string - for cursor.Next(ctx) { - var doc bson.M - if err := cursor.Decode(&doc); err != nil { - return nil, fmt.Errorf("decode failed: %w", err) - } - - jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") - if err != nil { - return nil, fmt.Errorf("marshal failed: %w", err) - } - rows = append(rows, string(jsonBytes)) - } - - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %w", err) - } - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// executeFindOne executes a findOne operation. -func executeFindOne(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - filter := op.filter - if filter == nil { - filter = bson.D{} - } - - opts := options.FindOne() - if op.sort != nil { - opts.SetSort(op.sort) - } - if op.skip != nil { - opts.SetSkip(*op.skip) - } - if op.projection != nil { - opts.SetProjection(op.projection) - } - if op.hint != nil { - opts.SetHint(op.hint) - } - if op.max != nil { - opts.SetMax(op.max) - } - if op.min != nil { - opts.SetMin(op.min) - } - - // Apply maxTimeMS using context timeout (see comment in executeFind for details). - if op.maxTimeMS != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.maxTimeMS)*time.Millisecond) - defer cancel() - } - - var doc bson.M - err := collection.FindOne(ctx, filter, opts).Decode(&doc) + op, err := translator.Parse(statement) if err != nil { - if err == mongo.ErrNoDocuments { - return &Result{ - Rows: nil, - RowCount: 0, - }, nil + // Convert internal errors to public errors + switch e := err.(type) { + case *translator.ParseError: + return nil, &ParseError{ + Line: e.Line, + Column: e.Column, + Message: e.Message, + Found: e.Found, + Expected: e.Expected, + } + case *translator.UnsupportedOperationError: + return nil, &UnsupportedOperationError{Operation: e.Operation} + case *translator.PlannedOperationError: + return nil, &PlannedOperationError{Operation: e.Operation} + case *translator.UnsupportedOptionError: + return nil, &UnsupportedOptionError{Method: e.Method, Option: e.Option} + default: + return nil, err } - return nil, fmt.Errorf("findOne failed: %w", err) } - jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") + result, err := executor.Execute(ctx, client, database, op, statement) if err != nil { - return nil, fmt.Errorf("marshal failed: %w", err) + return nil, err } return &Result{ - Rows: []string{string(jsonBytes)}, - RowCount: 1, + Rows: result.Rows, + RowCount: result.RowCount, + Statement: result.Statement, }, nil } - -// executeShowDatabases executes a show dbs/databases command. -func executeShowDatabases(ctx context.Context, client *mongo.Client) (*Result, error) { - names, err := client.ListDatabaseNames(ctx, bson.D{}) - if err != nil { - return nil, fmt.Errorf("list databases failed: %w", err) - } - - rows := make([]string, len(names)) - copy(rows, names) - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// executeShowCollections executes a show collections command. -func executeShowCollections(ctx context.Context, client *mongo.Client, database string) (*Result, error) { - names, err := client.Database(database).ListCollectionNames(ctx, bson.D{}) - if err != nil { - return nil, fmt.Errorf("list collections failed: %w", err) - } - - rows := make([]string, len(names)) - copy(rows, names) - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// executeGetCollectionNames executes a db.getCollectionNames() command. -func executeGetCollectionNames(ctx context.Context, client *mongo.Client, database string) (*Result, error) { - return executeShowCollections(ctx, client, database) -} - -// executeGetCollectionInfos executes a db.getCollectionInfos() command. -func executeGetCollectionInfos(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - filter := op.filter - if filter == nil { - filter = bson.D{} - } - - opts := options.ListCollections() - if op.nameOnly != nil { - opts.SetNameOnly(*op.nameOnly) - } - if op.authorizedCollections != nil { - opts.SetAuthorizedCollections(*op.authorizedCollections) - } - - cursor, err := client.Database(database).ListCollections(ctx, filter, opts) - if err != nil { - return nil, fmt.Errorf("list collections failed: %w", err) - } - defer func() { _ = cursor.Close(ctx) }() - - var rows []string - for cursor.Next(ctx) { - var doc bson.M - if err := cursor.Decode(&doc); err != nil { - return nil, fmt.Errorf("decode failed: %w", err) - } - - jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") - if err != nil { - return nil, fmt.Errorf("marshal failed: %w", err) - } - rows = append(rows, string(jsonBytes)) - } - - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %w", err) - } - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// executeGetIndexes executes a db.collection.getIndexes() command. -func executeGetIndexes(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - cursor, err := collection.Indexes().List(ctx) - if err != nil { - return nil, fmt.Errorf("list indexes failed: %w", err) - } - defer func() { _ = cursor.Close(ctx) }() - - var rows []string - for cursor.Next(ctx) { - var doc bson.M - if err := cursor.Decode(&doc); err != nil { - return nil, fmt.Errorf("decode failed: %w", err) - } - - jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") - if err != nil { - return nil, fmt.Errorf("marshal failed: %w", err) - } - rows = append(rows, string(jsonBytes)) - } - - if err := cursor.Err(); err != nil { - return nil, fmt.Errorf("cursor error: %w", err) - } - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// executeCountDocuments executes a db.collection.countDocuments() command. -func executeCountDocuments(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - filter := op.filter - if filter == nil { - filter = bson.D{} - } - - opts := options.Count() - if op.hint != nil { - opts.SetHint(op.hint) - } - if op.limit != nil { - opts.SetLimit(*op.limit) - } - if op.skip != nil { - opts.SetSkip(*op.skip) - } - - // Apply maxTimeMS using context timeout (see comment in executeFind for details). - if op.maxTimeMS != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.maxTimeMS)*time.Millisecond) - defer cancel() - } - - count, err := collection.CountDocuments(ctx, filter, opts) - if err != nil { - return nil, fmt.Errorf("count documents failed: %w", err) - } - - return &Result{ - Rows: []string{fmt.Sprintf("%d", count)}, - RowCount: 1, - }, nil -} - -// executeEstimatedDocumentCount executes a db.collection.estimatedDocumentCount() command. -func executeEstimatedDocumentCount(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - // Apply maxTimeMS using context timeout (see comment in executeFind for details). - if op.maxTimeMS != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.maxTimeMS)*time.Millisecond) - defer cancel() - } - - count, err := collection.EstimatedDocumentCount(ctx) - if err != nil { - return nil, fmt.Errorf("estimated document count failed: %w", err) - } - - return &Result{ - Rows: []string{fmt.Sprintf("%d", count)}, - RowCount: 1, - }, nil -} - -// executeDistinct executes a db.collection.distinct() command. -func executeDistinct(ctx context.Context, client *mongo.Client, database string, op *mongoOperation) (*Result, error) { - collection := client.Database(database).Collection(op.collection) - - filter := op.filter - if filter == nil { - filter = bson.D{} - } - - // Apply maxTimeMS using context timeout (see comment in executeFind for details). - if op.maxTimeMS != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.maxTimeMS)*time.Millisecond) - defer cancel() - } - - result := collection.Distinct(ctx, op.distinctField, filter) - if err := result.Err(); err != nil { - return nil, fmt.Errorf("distinct failed: %w", err) - } - - var values []any - if err := result.Decode(&values); err != nil { - return nil, fmt.Errorf("decode failed: %w", err) - } - - var rows []string - for _, val := range values { - jsonBytes, err := marshalValue(val) - if err != nil { - return nil, fmt.Errorf("marshal failed: %w", err) - } - rows = append(rows, string(jsonBytes)) - } - - return &Result{ - Rows: rows, - RowCount: len(rows), - }, nil -} - -// marshalValue marshals a value to JSON. -// bson.MarshalExtJSONIndent only works for documents/arrays at top level, -// so we use encoding/json for primitive values (strings, numbers, booleans). -func marshalValue(val any) ([]byte, error) { - switch v := val.(type) { - case bson.M, bson.D, map[string]any: - return bson.MarshalExtJSONIndent(v, false, false, "", " ") - case bson.A, []any: - return bson.MarshalExtJSONIndent(v, false, false, "", " ") - default: - return json.Marshal(v) - } -} diff --git a/internal/executor/collection.go b/internal/executor/collection.go new file mode 100644 index 0000000..0ac143c --- /dev/null +++ b/internal/executor/collection.go @@ -0,0 +1,345 @@ +package executor + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/bytebase/gomongo/internal/translator" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// executeFind executes a find operation. +func executeFind(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + filter := op.Filter + if filter == nil { + filter = bson.D{} + } + + opts := options.Find() + if op.Sort != nil { + opts.SetSort(op.Sort) + } + if op.Limit != nil { + opts.SetLimit(*op.Limit) + } + if op.Skip != nil { + opts.SetSkip(*op.Skip) + } + if op.Projection != nil { + opts.SetProjection(op.Projection) + } + if op.Hint != nil { + opts.SetHint(op.Hint) + } + if op.Max != nil { + opts.SetMax(op.Max) + } + if op.Min != nil { + opts.SetMin(op.Min) + } + + // Apply maxTimeMS using context timeout. + // Note: MongoDB Go driver v2 removed SetMaxTime() from options. The recommended + // replacement is context.WithTimeout(). This is a client-side timeout (includes + // network latency), unlike mongosh's maxTimeMS which is server-side only. + if op.MaxTimeMS != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.MaxTimeMS)*time.Millisecond) + defer cancel() + } + + cursor, err := collection.Find(ctx, filter, opts) + if err != nil { + return nil, fmt.Errorf("find failed: %w", err) + } + defer func() { _ = cursor.Close(ctx) }() + + var rows []string + for cursor.Next(ctx) { + var doc bson.M + if err := cursor.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode failed: %w", err) + } + + // Marshal to Extended JSON (Relaxed) + jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal failed: %w", err) + } + rows = append(rows, string(jsonBytes)) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %w", err) + } + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} + +// executeFindOne executes a findOne operation. +func executeFindOne(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + filter := op.Filter + if filter == nil { + filter = bson.D{} + } + + opts := options.FindOne() + if op.Sort != nil { + opts.SetSort(op.Sort) + } + if op.Skip != nil { + opts.SetSkip(*op.Skip) + } + if op.Projection != nil { + opts.SetProjection(op.Projection) + } + if op.Hint != nil { + opts.SetHint(op.Hint) + } + if op.Max != nil { + opts.SetMax(op.Max) + } + if op.Min != nil { + opts.SetMin(op.Min) + } + + // Apply maxTimeMS using context timeout (see comment in executeFind for details). + if op.MaxTimeMS != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.MaxTimeMS)*time.Millisecond) + defer cancel() + } + + var doc bson.M + err := collection.FindOne(ctx, filter, opts).Decode(&doc) + if err != nil { + if err == mongo.ErrNoDocuments { + return &Result{ + Rows: nil, + RowCount: 0, + }, nil + } + return nil, fmt.Errorf("findOne failed: %w", err) + } + + jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal failed: %w", err) + } + + return &Result{ + Rows: []string{string(jsonBytes)}, + RowCount: 1, + }, nil +} + +// executeAggregate executes an aggregation pipeline. +func executeAggregate(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + pipeline := op.Pipeline + if pipeline == nil { + pipeline = bson.A{} + } + + opts := options.Aggregate() + if op.Hint != nil { + opts.SetHint(op.Hint) + } + + // Apply maxTimeMS using context timeout (see comment in executeFind for details). + if op.MaxTimeMS != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.MaxTimeMS)*time.Millisecond) + defer cancel() + } + + cursor, err := collection.Aggregate(ctx, pipeline, opts) + if err != nil { + return nil, fmt.Errorf("aggregate failed: %w", err) + } + defer func() { _ = cursor.Close(ctx) }() + + var rows []string + for cursor.Next(ctx) { + var doc bson.M + if err := cursor.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode failed: %w", err) + } + + jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal failed: %w", err) + } + rows = append(rows, string(jsonBytes)) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %w", err) + } + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} + +// executeGetIndexes executes a db.collection.getIndexes() command. +func executeGetIndexes(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + cursor, err := collection.Indexes().List(ctx) + if err != nil { + return nil, fmt.Errorf("list indexes failed: %w", err) + } + defer func() { _ = cursor.Close(ctx) }() + + var rows []string + for cursor.Next(ctx) { + var doc bson.M + if err := cursor.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode failed: %w", err) + } + + jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal failed: %w", err) + } + rows = append(rows, string(jsonBytes)) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %w", err) + } + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} + +// executeCountDocuments executes a db.collection.countDocuments() command. +func executeCountDocuments(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + filter := op.Filter + if filter == nil { + filter = bson.D{} + } + + opts := options.Count() + if op.Hint != nil { + opts.SetHint(op.Hint) + } + if op.Limit != nil { + opts.SetLimit(*op.Limit) + } + if op.Skip != nil { + opts.SetSkip(*op.Skip) + } + + // Apply maxTimeMS using context timeout (see comment in executeFind for details). + if op.MaxTimeMS != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.MaxTimeMS)*time.Millisecond) + defer cancel() + } + + count, err := collection.CountDocuments(ctx, filter, opts) + if err != nil { + return nil, fmt.Errorf("count documents failed: %w", err) + } + + return &Result{ + Rows: []string{fmt.Sprintf("%d", count)}, + RowCount: 1, + }, nil +} + +// executeEstimatedDocumentCount executes a db.collection.estimatedDocumentCount() command. +func executeEstimatedDocumentCount(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + // Apply maxTimeMS using context timeout (see comment in executeFind for details). + if op.MaxTimeMS != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.MaxTimeMS)*time.Millisecond) + defer cancel() + } + + count, err := collection.EstimatedDocumentCount(ctx) + if err != nil { + return nil, fmt.Errorf("estimated document count failed: %w", err) + } + + return &Result{ + Rows: []string{fmt.Sprintf("%d", count)}, + RowCount: 1, + }, nil +} + +// executeDistinct executes a db.collection.distinct() command. +func executeDistinct(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + collection := client.Database(database).Collection(op.Collection) + + filter := op.Filter + if filter == nil { + filter = bson.D{} + } + + // Apply maxTimeMS using context timeout (see comment in executeFind for details). + if op.MaxTimeMS != nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(*op.MaxTimeMS)*time.Millisecond) + defer cancel() + } + + result := collection.Distinct(ctx, op.DistinctField, filter) + if err := result.Err(); err != nil { + return nil, fmt.Errorf("distinct failed: %w", err) + } + + var values []any + if err := result.Decode(&values); err != nil { + return nil, fmt.Errorf("decode failed: %w", err) + } + + var rows []string + for _, val := range values { + jsonBytes, err := marshalValue(val) + if err != nil { + return nil, fmt.Errorf("marshal failed: %w", err) + } + rows = append(rows, string(jsonBytes)) + } + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} + +// marshalValue marshals a value to JSON. +// bson.MarshalExtJSONIndent only works for documents/arrays at top level, +// so we use encoding/json for primitive values (strings, numbers, booleans). +func marshalValue(val any) ([]byte, error) { + switch v := val.(type) { + case bson.M, bson.D, map[string]any: + return bson.MarshalExtJSONIndent(v, false, false, "", " ") + case bson.A, []any: + return bson.MarshalExtJSONIndent(v, false, false, "", " ") + default: + return json.Marshal(v) + } +} diff --git a/internal/executor/database.go b/internal/executor/database.go new file mode 100644 index 0000000..d1bc257 --- /dev/null +++ b/internal/executor/database.go @@ -0,0 +1,77 @@ +package executor + +import ( + "context" + "fmt" + + "github.com/bytebase/gomongo/internal/translator" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// executeShowCollections executes a show collections command. +func executeShowCollections(ctx context.Context, client *mongo.Client, database string) (*Result, error) { + names, err := client.Database(database).ListCollectionNames(ctx, bson.D{}) + if err != nil { + return nil, fmt.Errorf("list collections failed: %w", err) + } + + rows := make([]string, len(names)) + copy(rows, names) + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} + +// executeGetCollectionNames executes a db.getCollectionNames() command. +func executeGetCollectionNames(ctx context.Context, client *mongo.Client, database string) (*Result, error) { + return executeShowCollections(ctx, client, database) +} + +// executeGetCollectionInfos executes a db.getCollectionInfos() command. +func executeGetCollectionInfos(ctx context.Context, client *mongo.Client, database string, op *translator.Operation) (*Result, error) { + filter := op.Filter + if filter == nil { + filter = bson.D{} + } + + opts := options.ListCollections() + if op.NameOnly != nil { + opts.SetNameOnly(*op.NameOnly) + } + if op.AuthorizedCollections != nil { + opts.SetAuthorizedCollections(*op.AuthorizedCollections) + } + + cursor, err := client.Database(database).ListCollections(ctx, filter, opts) + if err != nil { + return nil, fmt.Errorf("list collections failed: %w", err) + } + defer func() { _ = cursor.Close(ctx) }() + + var rows []string + for cursor.Next(ctx) { + var doc bson.M + if err := cursor.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode failed: %w", err) + } + + jsonBytes, err := bson.MarshalExtJSONIndent(doc, false, false, "", " ") + if err != nil { + return nil, fmt.Errorf("marshal failed: %w", err) + } + rows = append(rows, string(jsonBytes)) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("cursor error: %w", err) + } + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} diff --git a/internal/executor/executor.go b/internal/executor/executor.go new file mode 100644 index 0000000..6a1d139 --- /dev/null +++ b/internal/executor/executor.go @@ -0,0 +1,46 @@ +package executor + +import ( + "context" + "fmt" + + "github.com/bytebase/gomongo/internal/translator" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +// Result represents query execution results. +type Result struct { + Rows []string + RowCount int + Statement string +} + +// Execute executes a parsed operation against MongoDB. +func Execute(ctx context.Context, client *mongo.Client, database string, op *translator.Operation, statement string) (*Result, error) { + switch op.OpType { + case translator.OpFind: + return executeFind(ctx, client, database, op) + case translator.OpFindOne: + return executeFindOne(ctx, client, database, op) + case translator.OpAggregate: + return executeAggregate(ctx, client, database, op) + case translator.OpShowDatabases: + return executeShowDatabases(ctx, client) + case translator.OpShowCollections: + return executeShowCollections(ctx, client, database) + case translator.OpGetCollectionNames: + return executeGetCollectionNames(ctx, client, database) + case translator.OpGetCollectionInfos: + return executeGetCollectionInfos(ctx, client, database, op) + case translator.OpGetIndexes: + return executeGetIndexes(ctx, client, database, op) + case translator.OpCountDocuments: + return executeCountDocuments(ctx, client, database, op) + case translator.OpEstimatedDocumentCount: + return executeEstimatedDocumentCount(ctx, client, database, op) + case translator.OpDistinct: + return executeDistinct(ctx, client, database, op) + default: + return nil, fmt.Errorf("unsupported operation: %s", statement) + } +} diff --git a/internal/executor/server.go b/internal/executor/server.go new file mode 100644 index 0000000..bb3c236 --- /dev/null +++ b/internal/executor/server.go @@ -0,0 +1,25 @@ +package executor + +import ( + "context" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +// executeShowDatabases executes a show dbs/databases command. +func executeShowDatabases(ctx context.Context, client *mongo.Client) (*Result, error) { + names, err := client.ListDatabaseNames(ctx, bson.D{}) + if err != nil { + return nil, fmt.Errorf("list databases failed: %w", err) + } + + rows := make([]string, len(names)) + copy(rows, names) + + return &Result{ + Rows: rows, + RowCount: len(rows), + }, nil +} diff --git a/internal/testutil/container.go b/internal/testutil/container.go new file mode 100644 index 0000000..7b79403 --- /dev/null +++ b/internal/testutil/container.go @@ -0,0 +1,59 @@ +package testutil + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/testcontainers/testcontainers-go/modules/mongodb" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +var ( + container *mongodb.MongoDBContainer + client *mongo.Client + containerOnce sync.Once + containerErr error +) + +// GetClient returns a shared MongoDB client for testing. +// The container is started once and reused across all tests. +// Each test should use a unique database name to avoid interference. +func GetClient(t *testing.T) *mongo.Client { + t.Helper() + + containerOnce.Do(func() { + ctx := context.Background() + + container, containerErr = mongodb.Run(ctx, "mongo:7") + if containerErr != nil { + return + } + + connectionString, err := container.ConnectionString(ctx) + if err != nil { + containerErr = err + return + } + + client, containerErr = mongo.Connect(options.Client().ApplyURI(connectionString)) + }) + + if containerErr != nil { + t.Fatalf("failed to setup test container: %v", containerErr) + } + + return client +} + +// CleanupDatabase drops the specified database after a test. +func CleanupDatabase(t *testing.T, client *mongo.Client, dbName string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := client.Database(dbName).Drop(ctx); err != nil { + t.Logf("warning: failed to drop database %s: %v", dbName, err) + } +} diff --git a/helper_functions.go b/internal/translator/bson_helpers.go similarity index 85% rename from helper_functions.go rename to internal/translator/bson_helpers.go index 5ba7350..0a39aa4 100644 --- a/helper_functions.go +++ b/internal/translator/bson_helpers.go @@ -1,4 +1,4 @@ -package gomongo +package translator import ( "encoding/hex" @@ -11,6 +11,44 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" ) +// convertHelperFunction converts a helper function to a BSON value. +func convertHelperFunction(ctx mongodb.IHelperFunctionContext) (any, error) { + helper, ok := ctx.(*mongodb.HelperFunctionContext) + if !ok { + return nil, fmt.Errorf("invalid helper function context") + } + + if helper.ObjectIdHelper() != nil { + return convertObjectIdHelper(helper.ObjectIdHelper()) + } + if helper.IsoDateHelper() != nil { + return convertIsoDateHelper(helper.IsoDateHelper()) + } + if helper.DateHelper() != nil { + return convertDateHelper(helper.DateHelper()) + } + if helper.UuidHelper() != nil { + return convertUuidHelper(helper.UuidHelper()) + } + if helper.LongHelper() != nil { + return convertLongHelper(helper.LongHelper()) + } + if helper.Int32Helper() != nil { + return convertInt32Helper(helper.Int32Helper()) + } + if helper.DoubleHelper() != nil { + return convertDoubleHelper(helper.DoubleHelper()) + } + if helper.Decimal128Helper() != nil { + return convertDecimal128Helper(helper.Decimal128Helper()) + } + if helper.TimestampHelper() != nil { + return convertTimestampHelper(helper.TimestampHelper()) + } + + return nil, fmt.Errorf("unsupported helper function") +} + // convertObjectIdHelper converts ObjectId("hex") to primitive.ObjectID. func convertObjectIdHelper(ctx mongodb.IObjectIdHelperContext) (bson.ObjectID, error) { helper, ok := ctx.(*mongodb.ObjectIdHelperContext) diff --git a/internal/translator/collection.go b/internal/translator/collection.go new file mode 100644 index 0000000..ebfd62e --- /dev/null +++ b/internal/translator/collection.go @@ -0,0 +1,842 @@ +package translator + +import ( + "fmt" + "strconv" + + "github.com/bytebase/parser/mongodb" + "go.mongodb.org/mongo-driver/v2/bson" +) + +func (v *visitor) extractFindArgs(ctx mongodb.IFindMethodContext) { + fm, ok := ctx.(*mongodb.FindMethodContext) + if !ok { + return + } + + args := fm.Arguments() + if args == nil { + return + } + + argsCtx, ok := args.(*mongodb.ArgumentsContext) + if !ok { + return + } + + allArgs := argsCtx.AllArgument() + if len(allArgs) == 0 { + return + } + + // First argument: filter + firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) + if !ok { + return + } + valueCtx := firstArg.Value() + if valueCtx != nil { + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("find() filter must be a document") + return + } + filter, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid filter: %w", err) + return + } + v.operation.Filter = filter + } + + // Second argument: projection (optional) + if len(allArgs) >= 2 { + secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) + if !ok { + return + } + valueCtx := secondArg.Value() + if valueCtx != nil { + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("find() projection must be a document") + return + } + projection, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid projection: %w", err) + return + } + v.operation.Projection = projection + } + } + + // Third argument: options (optional) + if len(allArgs) >= 3 { + thirdArg, ok := allArgs[2].(*mongodb.ArgumentContext) + if !ok { + return + } + valueCtx := thirdArg.Value() + if valueCtx != nil { + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("find() options must be a document") + return + } + options, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + // Validate and extract supported options + for _, opt := range options { + switch opt.Key { + case "hint": + v.operation.Hint = opt.Value + case "max": + if doc, ok := opt.Value.(bson.D); ok { + v.operation.Max = doc + } else { + v.err = fmt.Errorf("find() max must be a document") + return + } + case "min": + if doc, ok := opt.Value.(bson.D); ok { + v.operation.Min = doc + } else { + v.err = fmt.Errorf("find() min must be a document") + return + } + case "maxTimeMS": + if val, ok := opt.Value.(int32); ok { + ms := int64(val) + v.operation.MaxTimeMS = &ms + } else if val, ok := opt.Value.(int64); ok { + v.operation.MaxTimeMS = &val + } else { + v.err = fmt.Errorf("find() maxTimeMS must be a number") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "find()", + Option: opt.Key, + } + return + } + } + } + } + + // More than 3 arguments is an error + if len(allArgs) > 3 { + v.err = fmt.Errorf("find() takes at most 3 arguments") + return + } +} + +func (v *visitor) extractFindOneArgs(ctx mongodb.IFindOneMethodContext) { + fm, ok := ctx.(*mongodb.FindOneMethodContext) + if !ok { + return + } + + args := fm.Arguments() + if args == nil { + return + } + + argsCtx, ok := args.(*mongodb.ArgumentsContext) + if !ok { + return + } + + allArgs := argsCtx.AllArgument() + if len(allArgs) == 0 { + return + } + + // First argument: filter + firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) + if !ok { + return + } + valueCtx := firstArg.Value() + if valueCtx != nil { + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("findOne() filter must be a document") + return + } + filter, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid filter: %w", err) + return + } + v.operation.Filter = filter + } + + // Second argument: projection (optional) + if len(allArgs) >= 2 { + secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) + if !ok { + return + } + valueCtx := secondArg.Value() + if valueCtx != nil { + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("findOne() projection must be a document") + return + } + projection, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid projection: %w", err) + return + } + v.operation.Projection = projection + } + } + + // Third argument: options (optional) + if len(allArgs) >= 3 { + thirdArg, ok := allArgs[2].(*mongodb.ArgumentContext) + if !ok { + return + } + valueCtx := thirdArg.Value() + if valueCtx != nil { + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("findOne() options must be a document") + return + } + options, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + // Validate and extract supported options + for _, opt := range options { + switch opt.Key { + case "hint": + v.operation.Hint = opt.Value + case "max": + if doc, ok := opt.Value.(bson.D); ok { + v.operation.Max = doc + } else { + v.err = fmt.Errorf("findOne() max must be a document") + return + } + case "min": + if doc, ok := opt.Value.(bson.D); ok { + v.operation.Min = doc + } else { + v.err = fmt.Errorf("findOne() min must be a document") + return + } + case "maxTimeMS": + if val, ok := opt.Value.(int32); ok { + ms := int64(val) + v.operation.MaxTimeMS = &ms + } else if val, ok := opt.Value.(int64); ok { + v.operation.MaxTimeMS = &val + } else { + v.err = fmt.Errorf("findOne() maxTimeMS must be a number") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "findOne()", + Option: opt.Key, + } + return + } + } + } + } + + // More than 3 arguments is an error + if len(allArgs) > 3 { + v.err = fmt.Errorf("findOne() takes at most 3 arguments") + return + } +} + +func (v *visitor) extractSort(ctx mongodb.ISortMethodContext) { + sm, ok := ctx.(*mongodb.SortMethodContext) + if !ok { + return + } + + doc := sm.Document() + if doc == nil { + v.err = fmt.Errorf("sort() requires a document argument") + return + } + + sort, err := convertDocument(doc) + if err != nil { + v.err = fmt.Errorf("invalid sort: %w", err) + return + } + v.operation.Sort = sort +} + +func (v *visitor) extractLimit(ctx mongodb.ILimitMethodContext) { + lm, ok := ctx.(*mongodb.LimitMethodContext) + if !ok { + return + } + + numNode := lm.NUMBER() + if numNode == nil { + v.err = fmt.Errorf("limit() requires a number argument") + return + } + + limit, err := strconv.ParseInt(numNode.GetText(), 10, 64) + if err != nil { + v.err = fmt.Errorf("invalid limit: %w", err) + return + } + v.operation.Limit = &limit +} + +func (v *visitor) extractSkip(ctx mongodb.ISkipMethodContext) { + sm, ok := ctx.(*mongodb.SkipMethodContext) + if !ok { + return + } + + numNode := sm.NUMBER() + if numNode == nil { + v.err = fmt.Errorf("skip() requires a number argument") + return + } + + skip, err := strconv.ParseInt(numNode.GetText(), 10, 64) + if err != nil { + v.err = fmt.Errorf("invalid skip: %w", err) + return + } + v.operation.Skip = &skip +} + +func (v *visitor) extractProjection(ctx mongodb.IProjectionMethodContext) { + pm, ok := ctx.(*mongodb.ProjectionMethodContext) + if !ok { + return + } + + doc := pm.Document() + if doc == nil { + v.err = fmt.Errorf("projection() requires a document argument") + return + } + + projection, err := convertDocument(doc) + if err != nil { + v.err = fmt.Errorf("invalid projection: %w", err) + return + } + v.operation.Projection = projection +} + +func (v *visitor) extractHint(ctx mongodb.IHintMethodContext) { + hm, ok := ctx.(*mongodb.HintMethodContext) + if !ok { + return + } + + arg := hm.Argument() + if arg == nil { + v.err = fmt.Errorf("hint() requires an argument") + return + } + + argCtx, ok := arg.(*mongodb.ArgumentContext) + if !ok { + return + } + + valueCtx := argCtx.Value() + if valueCtx == nil { + return + } + + // hint can be a string (index name) or document (index spec) + switch val := valueCtx.(type) { + case *mongodb.LiteralValueContext: + strLit, ok := val.Literal().(*mongodb.StringLiteralValueContext) + if !ok { + v.err = fmt.Errorf("hint() argument must be a string or document") + return + } + v.operation.Hint = unquoteString(strLit.StringLiteral().GetText()) + case *mongodb.DocumentValueContext: + doc, err := convertDocument(val.Document()) + if err != nil { + v.err = fmt.Errorf("invalid hint: %w", err) + return + } + v.operation.Hint = doc + default: + v.err = fmt.Errorf("hint() argument must be a string or document") + } +} + +func (v *visitor) extractMax(ctx mongodb.IMaxMethodContext) { + mm, ok := ctx.(*mongodb.MaxMethodContext) + if !ok { + return + } + + doc := mm.Document() + if doc == nil { + v.err = fmt.Errorf("max() requires a document argument") + return + } + + maxDoc, err := convertDocument(doc) + if err != nil { + v.err = fmt.Errorf("invalid max: %w", err) + return + } + v.operation.Max = maxDoc +} + +func (v *visitor) extractMin(ctx mongodb.IMinMethodContext) { + mm, ok := ctx.(*mongodb.MinMethodContext) + if !ok { + return + } + + doc := mm.Document() + if doc == nil { + v.err = fmt.Errorf("min() requires a document argument") + return + } + + minDoc, err := convertDocument(doc) + if err != nil { + v.err = fmt.Errorf("invalid min: %w", err) + return + } + v.operation.Min = minDoc +} + +// extractAggregationPipelineFromMethod extracts pipeline from AggregateMethodContext. +func (v *visitor) extractAggregationPipelineFromMethod(ctx mongodb.IAggregateMethodContext) { + method, ok := ctx.(*mongodb.AggregateMethodContext) + if !ok { + return + } + v.extractArgumentsForAggregate(method.Arguments()) +} + +// extractArgumentsForAggregate extracts aggregate pipeline from IArgumentsContext. +func (v *visitor) extractArgumentsForAggregate(args mongodb.IArgumentsContext) { + if args == nil { + // Empty pipeline: aggregate() + v.operation.Pipeline = bson.A{} + return + } + + argsCtx, ok := args.(*mongodb.ArgumentsContext) + if !ok { + v.err = fmt.Errorf("aggregate() requires an array argument") + return + } + + allArgs := argsCtx.AllArgument() + if len(allArgs) == 0 { + v.operation.Pipeline = bson.A{} + return + } + + // First argument should be the pipeline array + firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) + if !ok { + v.err = fmt.Errorf("aggregate() requires an array argument") + return + } + + valueCtx := firstArg.Value() + if valueCtx == nil { + v.err = fmt.Errorf("aggregate() requires an array argument") + return + } + + arrayValue, ok := valueCtx.(*mongodb.ArrayValueContext) + if !ok { + v.err = fmt.Errorf("aggregate() requires an array argument, got %T", valueCtx) + return + } + + pipeline, err := convertArray(arrayValue.Array()) + if err != nil { + v.err = fmt.Errorf("invalid aggregation pipeline: %w", err) + return + } + + v.operation.Pipeline = pipeline + + // Second argument: options (optional) + if len(allArgs) >= 2 { + secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) + if !ok { + return + } + optionsValueCtx := secondArg.Value() + if optionsValueCtx == nil { + return + } + docValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("aggregate() options must be a document") + return + } + options, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + for _, opt := range options { + switch opt.Key { + case "hint": + v.operation.Hint = opt.Value + case "maxTimeMS": + if val, ok := opt.Value.(int32); ok { + ms := int64(val) + v.operation.MaxTimeMS = &ms + } else if val, ok := opt.Value.(int64); ok { + v.operation.MaxTimeMS = &val + } else { + v.err = fmt.Errorf("aggregate() maxTimeMS must be a number") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "aggregate()", + Option: opt.Key, + } + return + } + } + } + + // More than 2 arguments is an error + if len(allArgs) > 2 { + v.err = fmt.Errorf("aggregate() takes at most 2 arguments") + return + } +} + +// extractCountDocumentsArgsFromMethod extracts arguments from CountDocumentsMethodContext. +func (v *visitor) extractCountDocumentsArgsFromMethod(ctx mongodb.ICountDocumentsMethodContext) { + method, ok := ctx.(*mongodb.CountDocumentsMethodContext) + if !ok { + return + } + v.extractArgumentsForCountDocuments(method.Arguments()) +} + +// extractArgumentsForCountDocuments extracts countDocuments arguments from IArgumentsContext. +func (v *visitor) extractArgumentsForCountDocuments(args mongodb.IArgumentsContext) { + if args == nil { + return + } + + argsCtx, ok := args.(*mongodb.ArgumentsContext) + if !ok { + return + } + + allArgs := argsCtx.AllArgument() + if len(allArgs) == 0 { + return + } + + // First argument is the filter (optional) + firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) + if !ok { + return + } + + valueCtx := firstArg.Value() + if valueCtx == nil { + return + } + + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("countDocuments() filter must be a document") + return + } + + filter, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid filter: %w", err) + return + } + v.operation.Filter = filter + + // Second argument is the options (optional) + if len(allArgs) < 2 { + return + } + + secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) + if !ok { + return + } + + optionsValueCtx := secondArg.Value() + if optionsValueCtx == nil { + return + } + + optionsDocValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("countDocuments() options must be a document") + return + } + + optionsDoc, err := convertDocument(optionsDocValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + + // Extract supported options: hint, limit, skip, maxTimeMS + for _, elem := range optionsDoc { + switch elem.Key { + case "hint": + v.operation.Hint = elem.Value + case "limit": + if val, ok := elem.Value.(int32); ok { + limit := int64(val) + v.operation.Limit = &limit + } else if val, ok := elem.Value.(int64); ok { + v.operation.Limit = &val + } + case "skip": + if val, ok := elem.Value.(int32); ok { + skip := int64(val) + v.operation.Skip = &skip + } else if val, ok := elem.Value.(int64); ok { + v.operation.Skip = &val + } + case "maxTimeMS": + if val, ok := elem.Value.(int32); ok { + ms := int64(val) + v.operation.MaxTimeMS = &ms + } else if val, ok := elem.Value.(int64); ok { + v.operation.MaxTimeMS = &val + } else { + v.err = fmt.Errorf("countDocuments() maxTimeMS must be a number") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "countDocuments()", + Option: elem.Key, + } + return + } + } +} + +// extractEstimatedDocumentCountArgs extracts arguments from EstimatedDocumentCountMethodContext. +func (v *visitor) extractEstimatedDocumentCountArgs(ctx mongodb.IEstimatedDocumentCountMethodContext) { + method, ok := ctx.(*mongodb.EstimatedDocumentCountMethodContext) + if !ok { + return + } + + // EstimatedDocumentCountMethodContext has Argument() (singular) that returns a single optional argument + arg := method.Argument() + if arg == nil { + return + } + + argCtx, ok := arg.(*mongodb.ArgumentContext) + if !ok { + return + } + + valueCtx := argCtx.Value() + if valueCtx == nil { + return + } + + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("estimatedDocumentCount() options must be a document") + return + } + + options, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + + for _, opt := range options { + switch opt.Key { + case "maxTimeMS": + if val, ok := opt.Value.(int32); ok { + ms := int64(val) + v.operation.MaxTimeMS = &ms + } else if val, ok := opt.Value.(int64); ok { + v.operation.MaxTimeMS = &val + } else { + v.err = fmt.Errorf("estimatedDocumentCount() maxTimeMS must be a number") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "estimatedDocumentCount()", + Option: opt.Key, + } + return + } + } +} + +// extractDistinctArgsFromMethod extracts arguments from DistinctMethodContext. +func (v *visitor) extractDistinctArgsFromMethod(ctx mongodb.IDistinctMethodContext) { + method, ok := ctx.(*mongodb.DistinctMethodContext) + if !ok { + return + } + v.extractArgumentsForDistinct(method.Arguments()) +} + +// extractArgumentsForDistinct extracts distinct arguments from IArgumentsContext. +func (v *visitor) extractArgumentsForDistinct(args mongodb.IArgumentsContext) { + if args == nil { + v.err = fmt.Errorf("distinct() requires a field name argument") + return + } + + argsCtx, ok := args.(*mongodb.ArgumentsContext) + if !ok { + v.err = fmt.Errorf("distinct() requires a field name argument") + return + } + + allArgs := argsCtx.AllArgument() + if len(allArgs) == 0 { + v.err = fmt.Errorf("distinct() requires a field name argument") + return + } + + // First argument is the field name (required) + firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) + if !ok { + v.err = fmt.Errorf("distinct() requires a field name argument") + return + } + + valueCtx := firstArg.Value() + if valueCtx == nil { + v.err = fmt.Errorf("distinct() requires a field name argument") + return + } + + literalValue, ok := valueCtx.(*mongodb.LiteralValueContext) + if !ok { + v.err = fmt.Errorf("distinct() field name must be a string") + return + } + + stringLiteral, ok := literalValue.Literal().(*mongodb.StringLiteralValueContext) + if !ok { + v.err = fmt.Errorf("distinct() field name must be a string") + return + } + + v.operation.DistinctField = unquoteString(stringLiteral.StringLiteral().GetText()) + + // Second argument is the filter (optional) + if len(allArgs) < 2 { + return + } + + secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) + if !ok { + return + } + + filterValueCtx := secondArg.Value() + if filterValueCtx == nil { + return + } + + docValue, ok := filterValueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("distinct() filter must be a document") + return + } + + filter, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid filter: %w", err) + return + } + v.operation.Filter = filter + + // Third argument: options (optional) + if len(allArgs) >= 3 { + thirdArg, ok := allArgs[2].(*mongodb.ArgumentContext) + if !ok { + return + } + + optionsValueCtx := thirdArg.Value() + if optionsValueCtx == nil { + return + } + + docValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("distinct() options must be a document") + return + } + + options, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + + for _, opt := range options { + switch opt.Key { + case "maxTimeMS": + if val, ok := opt.Value.(int32); ok { + ms := int64(val) + v.operation.MaxTimeMS = &ms + } else if val, ok := opt.Value.(int64); ok { + v.operation.MaxTimeMS = &val + } else { + v.err = fmt.Errorf("distinct() maxTimeMS must be a number") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "distinct()", + Option: opt.Key, + } + return + } + } + } + + if len(allArgs) > 3 { + v.err = fmt.Errorf("distinct() takes at most 3 arguments") + return + } +} diff --git a/internal/translator/database.go b/internal/translator/database.go new file mode 100644 index 0000000..b6989dc --- /dev/null +++ b/internal/translator/database.go @@ -0,0 +1,103 @@ +package translator + +import ( + "fmt" + + "github.com/bytebase/parser/mongodb" +) + +func (v *visitor) extractGetCollectionInfosArgs(ctx *mongodb.GetCollectionInfosContext) { + args := ctx.Arguments() + if args == nil { + return + } + + argsCtx, ok := args.(*mongodb.ArgumentsContext) + if !ok { + return + } + + allArgs := argsCtx.AllArgument() + if len(allArgs) == 0 { + return + } + + // First argument is the filter (optional) + firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) + if !ok { + return + } + + valueCtx := firstArg.Value() + if valueCtx == nil { + return + } + + docValue, ok := valueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("getCollectionInfos() filter must be a document") + return + } + + filter, err := convertDocument(docValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid filter: %w", err) + return + } + v.operation.Filter = filter + + // Second argument is the options (optional) + if len(allArgs) >= 2 { + secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) + if !ok { + return + } + + optionsValueCtx := secondArg.Value() + if optionsValueCtx == nil { + return + } + + optionsDocValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) + if !ok { + v.err = fmt.Errorf("getCollectionInfos() options must be a document") + return + } + + optionsDoc, err := convertDocument(optionsDocValue.Document()) + if err != nil { + v.err = fmt.Errorf("invalid options: %w", err) + return + } + + for _, opt := range optionsDoc { + switch opt.Key { + case "nameOnly": + if val, ok := opt.Value.(bool); ok { + v.operation.NameOnly = &val + } else { + v.err = fmt.Errorf("getCollectionInfos() nameOnly must be a boolean") + return + } + case "authorizedCollections": + if val, ok := opt.Value.(bool); ok { + v.operation.AuthorizedCollections = &val + } else { + v.err = fmt.Errorf("getCollectionInfos() authorizedCollections must be a boolean") + return + } + default: + v.err = &UnsupportedOptionError{ + Method: "getCollectionInfos()", + Option: opt.Key, + } + return + } + } + } + + if len(allArgs) > 2 { + v.err = fmt.Errorf("getCollectionInfos() takes at most 2 arguments") + return + } +} diff --git a/internal/translator/errors.go b/internal/translator/errors.go new file mode 100644 index 0000000..6beb2be --- /dev/null +++ b/internal/translator/errors.go @@ -0,0 +1,49 @@ +package translator + +import "fmt" + +// ParseError represents a syntax error during parsing. +type ParseError struct { + Line int + Column int + Message string + Found string + Expected string +} + +func (e *ParseError) Error() string { + if e.Found != "" && e.Expected != "" { + return fmt.Sprintf("parse error at line %d, column %d: found %q, expected %s", e.Line, e.Column, e.Found, e.Expected) + } + return fmt.Sprintf("parse error at line %d, column %d: %s", e.Line, e.Column, e.Message) +} + +// UnsupportedOperationError represents an unsupported operation. +// This is returned for operations that are not planned for implementation. +type UnsupportedOperationError struct { + Operation string +} + +func (e *UnsupportedOperationError) Error() string { + return fmt.Sprintf("unsupported operation: %s", e.Operation) +} + +// PlannedOperationError represents an operation that is planned but not yet implemented. +// When the caller receives this error, it should fallback to mongosh. +type PlannedOperationError struct { + Operation string +} + +func (e *PlannedOperationError) Error() string { + return fmt.Sprintf("operation %s is not yet implemented", e.Operation) +} + +// UnsupportedOptionError represents an unsupported option in a supported method. +type UnsupportedOptionError struct { + Method string + Option string +} + +func (e *UnsupportedOptionError) Error() string { + return fmt.Sprintf("unsupported option '%s' in %s", e.Option, e.Method) +} diff --git a/internal/translator/helpers.go b/internal/translator/helpers.go new file mode 100644 index 0000000..403d4cc --- /dev/null +++ b/internal/translator/helpers.go @@ -0,0 +1,183 @@ +package translator + +import ( + "fmt" + "strconv" + "strings" + + "github.com/bytebase/parser/mongodb" + "go.mongodb.org/mongo-driver/v2/bson" +) + +// unquoteString removes quotes from a string literal. +func unquoteString(s string) string { + if len(s) >= 2 { + if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') { + return s[1 : len(s)-1] + } + } + return s +} + +// convertValue converts a parsed value context to a Go value for BSON. +func convertValue(ctx mongodb.IValueContext) (any, error) { + switch v := ctx.(type) { + case *mongodb.DocumentValueContext: + return convertDocument(v.Document()) + case *mongodb.ArrayValueContext: + return convertArray(v.Array()) + case *mongodb.LiteralValueContext: + return convertLiteral(v.Literal()) + case *mongodb.HelperValueContext: + return convertHelperFunction(v.HelperFunction()) + case *mongodb.RegexLiteralValueContext: + return convertRegexLiteral(v.REGEX_LITERAL().GetText()) + case *mongodb.RegexpConstructorValueContext: + return convertRegExpConstructor(v.RegExpConstructor()) + default: + return nil, fmt.Errorf("unsupported value type: %T", ctx) + } +} + +// convertDocument converts a document context to bson.D. +func convertDocument(ctx mongodb.IDocumentContext) (bson.D, error) { + doc, ok := ctx.(*mongodb.DocumentContext) + if !ok { + return nil, fmt.Errorf("invalid document context") + } + + result := bson.D{} + for _, pair := range doc.AllPair() { + key, value, err := convertPair(pair) + if err != nil { + return nil, err + } + result = append(result, bson.E{Key: key, Value: value}) + } + return result, nil +} + +// convertPair converts a pair context to key-value. +func convertPair(ctx mongodb.IPairContext) (string, any, error) { + pair, ok := ctx.(*mongodb.PairContext) + if !ok { + return "", nil, fmt.Errorf("invalid pair context") + } + + key := extractKey(pair.Key()) + value, err := convertValue(pair.Value()) + if err != nil { + return "", nil, fmt.Errorf("error converting value for key %q: %w", key, err) + } + return key, value, nil +} + +// extractKey extracts the key string from a key context. +func extractKey(ctx mongodb.IKeyContext) string { + switch k := ctx.(type) { + case *mongodb.UnquotedKeyContext: + return k.Identifier().GetText() + case *mongodb.QuotedKeyContext: + return unquoteString(k.StringLiteral().GetText()) + default: + return "" + } +} + +// convertArray converts an array context to bson.A. +func convertArray(ctx mongodb.IArrayContext) (bson.A, error) { + arr, ok := ctx.(*mongodb.ArrayContext) + if !ok { + return nil, fmt.Errorf("invalid array context") + } + + result := bson.A{} + for _, val := range arr.AllValue() { + v, err := convertValue(val) + if err != nil { + return nil, err + } + result = append(result, v) + } + return result, nil +} + +// convertLiteral converts a literal context to a Go value. +func convertLiteral(ctx mongodb.ILiteralContext) (any, error) { + switch l := ctx.(type) { + case *mongodb.NumberLiteralContext: + return parseNumber(l.NUMBER().GetText()) + case *mongodb.StringLiteralValueContext: + return unquoteString(l.StringLiteral().GetText()), nil + case *mongodb.TrueLiteralContext: + return true, nil + case *mongodb.FalseLiteralContext: + return false, nil + case *mongodb.NullLiteralContext: + return nil, nil + default: + return nil, fmt.Errorf("unsupported literal type: %T", ctx) + } +} + +// parseNumber parses a number string to int32, int64, or float64. +func parseNumber(s string) (any, error) { + if strings.Contains(s, ".") || strings.Contains(s, "e") || strings.Contains(s, "E") { + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, fmt.Errorf("invalid number: %s", s) + } + return f, nil + } + + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid number: %s", s) + } + + if i >= -2147483648 && i <= 2147483647 { + return int32(i), nil + } + return i, nil +} + +// convertRegexLiteral converts a regex literal like /pattern/flags to bson.Regex. +func convertRegexLiteral(text string) (bson.Regex, error) { + if len(text) < 2 || text[0] != '/' { + return bson.Regex{}, fmt.Errorf("invalid regex literal: %s", text) + } + + lastSlash := strings.LastIndex(text, "/") + if lastSlash <= 0 { + return bson.Regex{}, fmt.Errorf("invalid regex literal: %s", text) + } + + pattern := text[1:lastSlash] + options := "" + if lastSlash < len(text)-1 { + options = text[lastSlash+1:] + } + + return bson.Regex{Pattern: pattern, Options: options}, nil +} + +// convertRegExpConstructor converts RegExp("pattern", "flags") to bson.Regex. +func convertRegExpConstructor(ctx mongodb.IRegExpConstructorContext) (bson.Regex, error) { + constructor, ok := ctx.(*mongodb.RegExpConstructorContext) + if !ok { + return bson.Regex{}, fmt.Errorf("invalid RegExp constructor context") + } + + strings := constructor.AllStringLiteral() + if len(strings) == 0 { + return bson.Regex{}, fmt.Errorf("RegExp requires at least a pattern argument") + } + + pattern := unquoteString(strings[0].GetText()) + options := "" + if len(strings) > 1 { + options = unquoteString(strings[1].GetText()) + } + + return bson.Regex{Pattern: pattern, Options: options}, nil +} diff --git a/internal/translator/method_registry.go b/internal/translator/method_registry.go new file mode 100644 index 0000000..6456ceb --- /dev/null +++ b/internal/translator/method_registry.go @@ -0,0 +1,90 @@ +package translator + +// methodStatus represents the support status of a MongoDB method. +type methodStatus int + +const ( + // statusPlanned means the method is planned for implementation (M2/M3). + // When encountered, the caller should fallback to mongosh. + statusPlanned methodStatus = iota +) + +// methodInfo contains metadata about a MongoDB method. +type methodInfo struct { + status methodStatus +} + +// methodRegistry contains only methods we plan to implement (M2, M3). +// If a method is NOT in this registry, it's unsupported (throw error, no fallback). +// If a method IS in this registry, it's planned (fallback to mongosh). +var methodRegistry = map[string]methodInfo{ + // ============================================================ + // MILESTONE 2: Write Operations (10 methods) + // ============================================================ + + // Insert Commands (2) + "collection:insertOne": {status: statusPlanned}, + "collection:insertMany": {status: statusPlanned}, + + // Update Commands (3) + "collection:updateOne": {status: statusPlanned}, + "collection:updateMany": {status: statusPlanned}, + "collection:replaceOne": {status: statusPlanned}, + + // Delete Commands (2) + "collection:deleteOne": {status: statusPlanned}, + "collection:deleteMany": {status: statusPlanned}, + + // Atomic Find-and-Modify Commands (3) + "collection:findOneAndUpdate": {status: statusPlanned}, + "collection:findOneAndReplace": {status: statusPlanned}, + "collection:findOneAndDelete": {status: statusPlanned}, + + // ============================================================ + // MILESTONE 3: Administrative Operations (22 methods) + // ============================================================ + + // Index Management (4) + "collection:createIndex": {status: statusPlanned}, + "collection:createIndexes": {status: statusPlanned}, + "collection:dropIndex": {status: statusPlanned}, + "collection:dropIndexes": {status: statusPlanned}, + + // Collection Management (4) + "database:createCollection": {status: statusPlanned}, + "collection:drop": {status: statusPlanned}, + "collection:renameCollection": {status: statusPlanned}, + "database:dropDatabase": {status: statusPlanned}, + + // Database Information (7) + "database:stats": {status: statusPlanned}, + "collection:stats": {status: statusPlanned}, + "database:serverStatus": {status: statusPlanned}, + "database:serverBuildInfo": {status: statusPlanned}, + "database:version": {status: statusPlanned}, + "database:hostInfo": {status: statusPlanned}, + "database:listCommands": {status: statusPlanned}, + + // Collection Information (7) + "collection:dataSize": {status: statusPlanned}, + "collection:storageSize": {status: statusPlanned}, + "collection:totalIndexSize": {status: statusPlanned}, + "collection:totalSize": {status: statusPlanned}, + "collection:isCapped": {status: statusPlanned}, + "collection:validate": {status: statusPlanned}, + "collection:latencyStats": {status: statusPlanned}, +} + +// IsPlannedMethod checks if a method is in the registry (planned for implementation). +// Returns true if the method should fallback to mongosh. +// Returns false if the method is unsupported (throw error). +func IsPlannedMethod(context, methodName string) bool { + key := context + ":" + methodName + _, ok := methodRegistry[key] + return ok +} + +// MethodRegistryStats returns statistics about the method registry. +func MethodRegistryStats() int { + return len(methodRegistry) +} diff --git a/internal/translator/translator.go b/internal/translator/translator.go new file mode 100644 index 0000000..fde46c9 --- /dev/null +++ b/internal/translator/translator.go @@ -0,0 +1,47 @@ +package translator + +import ( + "github.com/antlr4-go/antlr/v4" + "github.com/bytebase/parser/mongodb" +) + +// Parse parses a MongoDB shell statement and returns the operation. +func Parse(statement string) (*Operation, error) { + tree, parseErrors := parseMongoShell(statement) + if len(parseErrors) > 0 { + return nil, &ParseError{ + Line: parseErrors[0].Line, + Column: parseErrors[0].Column, + Message: parseErrors[0].Message, + } + } + + visitor := newVisitor() + visitor.Visit(tree) + if visitor.err != nil { + return nil, visitor.err + } + + return visitor.operation, nil +} + +// parseMongoShell parses a MongoDB shell statement and returns the parse tree. +func parseMongoShell(statement string) (mongodb.IProgramContext, []*mongodb.MongoShellParseError) { + is := antlr.NewInputStream(statement) + lexer := mongodb.NewMongoShellLexer(is) + + errorListener := mongodb.NewMongoShellErrorListener() + lexer.RemoveErrorListeners() + lexer.AddErrorListener(errorListener) + + stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) + parser := mongodb.NewMongoShellParser(stream) + + parser.RemoveErrorListeners() + parser.AddErrorListener(errorListener) + + parser.BuildParseTrees = true + tree := parser.Program() + + return tree, errorListener.Errors +} diff --git a/internal/translator/types.go b/internal/translator/types.go new file mode 100644 index 0000000..f44191d --- /dev/null +++ b/internal/translator/types.go @@ -0,0 +1,45 @@ +package translator + +import "go.mongodb.org/mongo-driver/v2/bson" + +// OperationType represents the type of MongoDB operation. +type OperationType int + +const ( + OpUnknown OperationType = iota + OpFind + OpFindOne + OpAggregate + OpShowDatabases + OpShowCollections + OpGetCollectionNames + OpGetCollectionInfos + OpGetIndexes + OpCountDocuments + OpEstimatedDocumentCount + OpDistinct +) + +// Operation represents a parsed MongoDB operation. +type Operation struct { + OpType OperationType + Collection string + Filter bson.D + // Read operation options (find, findOne) + Sort bson.D + Limit *int64 + Skip *int64 + Projection bson.D + // Index scan bounds and query options + Hint any // string (index name) or document (index spec) + Max bson.D // upper bound for index scan + Min bson.D // lower bound for index scan + MaxTimeMS *int64 // max execution time in milliseconds + // Aggregation pipeline + Pipeline bson.A + // distinct field name + DistinctField string + // getCollectionInfos options + NameOnly *bool + AuthorizedCollections *bool +} diff --git a/internal/translator/visitor.go b/internal/translator/visitor.go new file mode 100644 index 0000000..d78a4f5 --- /dev/null +++ b/internal/translator/visitor.go @@ -0,0 +1,284 @@ +package translator + +import ( + "strings" + + "github.com/antlr4-go/antlr/v4" + "github.com/bytebase/parser/mongodb" +) + +// visitor extracts operations from a parse tree. +type visitor struct { + mongodb.BaseMongoShellParserVisitor + operation *Operation + err error +} + +func newVisitor() *visitor { + return &visitor{ + operation: &Operation{OpType: OpUnknown}, + } +} + +func (v *visitor) Visit(tree antlr.ParseTree) any { + return tree.Accept(v) +} + +func (v *visitor) VisitProgram(ctx *mongodb.ProgramContext) any { + v.visitProgram(ctx) + return nil +} + +func (v *visitor) visitProgram(ctx mongodb.IProgramContext) { + for _, stmt := range ctx.AllStatement() { + v.visitStatement(stmt) + if v.err != nil { + return + } + } +} + +func (v *visitor) VisitStatement(ctx *mongodb.StatementContext) any { + v.visitStatement(ctx) + return nil +} + +func (v *visitor) visitStatement(ctx mongodb.IStatementContext) { + if ctx.DbStatement() != nil { + v.visitDbStatement(ctx.DbStatement()) + } else if ctx.ShellCommand() != nil { + v.visitShellCommand(ctx.ShellCommand()) + } +} + +func (v *visitor) visitDbStatement(ctx mongodb.IDbStatementContext) { + switch c := ctx.(type) { + case *mongodb.CollectionOperationContext: + v.visitCollectionOperation(c) + case *mongodb.GetCollectionNamesContext: + v.operation.OpType = OpGetCollectionNames + case *mongodb.GetCollectionInfosContext: + v.operation.OpType = OpGetCollectionInfos + v.extractGetCollectionInfosArgs(c) + } +} + +func (v *visitor) visitShellCommand(ctx mongodb.IShellCommandContext) { + switch ctx.(type) { + case *mongodb.ShowDatabasesContext: + v.operation.OpType = OpShowDatabases + case *mongodb.ShowCollectionsContext: + v.operation.OpType = OpShowCollections + default: + v.err = &UnsupportedOperationError{ + Operation: ctx.GetText(), + } + } +} + +func (v *visitor) VisitCollectionOperation(ctx *mongodb.CollectionOperationContext) any { + v.visitCollectionOperation(ctx) + return nil +} + +func (v *visitor) visitCollectionOperation(ctx *mongodb.CollectionOperationContext) { + v.operation.Collection = v.extractCollectionName(ctx.CollectionAccess()) + + if ctx.MethodChain() != nil { + v.visitMethodChain(ctx.MethodChain()) + } +} + +func (v *visitor) VisitGetCollectionNames(_ *mongodb.GetCollectionNamesContext) any { + v.operation.OpType = OpGetCollectionNames + return nil +} + +func (v *visitor) VisitGetCollectionInfos(ctx *mongodb.GetCollectionInfosContext) any { + v.operation.OpType = OpGetCollectionInfos + v.extractGetCollectionInfosArgs(ctx) + return nil +} + +func (v *visitor) extractCollectionName(ctx mongodb.ICollectionAccessContext) string { + switch c := ctx.(type) { + case *mongodb.DotAccessContext: + return c.Identifier().GetText() + case *mongodb.BracketAccessContext: + return unquoteString(c.StringLiteral().GetText()) + case *mongodb.GetCollectionAccessContext: + return unquoteString(c.StringLiteral().GetText()) + } + return "" +} + +func (v *visitor) visitMethodChain(ctx mongodb.IMethodChainContext) { + mc, ok := ctx.(*mongodb.MethodChainContext) + if !ok { + return + } + for _, methodCall := range mc.AllMethodCall() { + v.visitMethodCall(methodCall) + if v.err != nil { + return + } + } +} + +func (v *visitor) visitMethodCall(ctx mongodb.IMethodCallContext) { + mc, ok := ctx.(*mongodb.MethodCallContext) + if !ok { + return + } + + // Determine method context for registry lookup + getMethodContext := func() string { + if v.operation.OpType == OpFind || v.operation.OpType == OpFindOne { + return "cursor" + } + return "collection" + } + + switch { + // Supported read operations + case mc.FindMethod() != nil: + v.operation.OpType = OpFind + v.extractFindArgs(mc.FindMethod()) + case mc.FindOneMethod() != nil: + v.operation.OpType = OpFindOne + v.extractFindOneArgs(mc.FindOneMethod()) + case mc.CountDocumentsMethod() != nil: + v.operation.OpType = OpCountDocuments + v.extractCountDocumentsArgsFromMethod(mc.CountDocumentsMethod()) + case mc.EstimatedDocumentCountMethod() != nil: + v.operation.OpType = OpEstimatedDocumentCount + v.extractEstimatedDocumentCountArgs(mc.EstimatedDocumentCountMethod()) + case mc.DistinctMethod() != nil: + v.operation.OpType = OpDistinct + v.extractDistinctArgsFromMethod(mc.DistinctMethod()) + case mc.AggregateMethod() != nil: + v.operation.OpType = OpAggregate + v.extractAggregationPipelineFromMethod(mc.AggregateMethod()) + case mc.GetIndexesMethod() != nil: + v.operation.OpType = OpGetIndexes + + // Supported cursor modifiers + case mc.SortMethod() != nil: + v.extractSort(mc.SortMethod()) + case mc.LimitMethod() != nil: + v.extractLimit(mc.LimitMethod()) + case mc.SkipMethod() != nil: + v.extractSkip(mc.SkipMethod()) + case mc.ProjectionMethod() != nil: + v.extractProjection(mc.ProjectionMethod()) + case mc.HintMethod() != nil: + v.extractHint(mc.HintMethod()) + case mc.MaxMethod() != nil: + v.extractMax(mc.MaxMethod()) + case mc.MinMethod() != nil: + v.extractMin(mc.MinMethod()) + + // Planned M2 write operations - return PlannedOperationError for fallback + case mc.InsertOneMethod() != nil: + v.handleUnsupportedMethod("collection", "insertOne") + case mc.InsertManyMethod() != nil: + v.handleUnsupportedMethod("collection", "insertMany") + case mc.UpdateOneMethod() != nil: + v.handleUnsupportedMethod("collection", "updateOne") + case mc.UpdateManyMethod() != nil: + v.handleUnsupportedMethod("collection", "updateMany") + case mc.DeleteOneMethod() != nil: + v.handleUnsupportedMethod("collection", "deleteOne") + case mc.DeleteManyMethod() != nil: + v.handleUnsupportedMethod("collection", "deleteMany") + case mc.ReplaceOneMethod() != nil: + v.handleUnsupportedMethod("collection", "replaceOne") + case mc.FindOneAndUpdateMethod() != nil: + v.handleUnsupportedMethod("collection", "findOneAndUpdate") + case mc.FindOneAndReplaceMethod() != nil: + v.handleUnsupportedMethod("collection", "findOneAndReplace") + case mc.FindOneAndDeleteMethod() != nil: + v.handleUnsupportedMethod("collection", "findOneAndDelete") + + // Planned M3 index operations - return PlannedOperationError for fallback + case mc.CreateIndexMethod() != nil: + v.handleUnsupportedMethod("collection", "createIndex") + case mc.CreateIndexesMethod() != nil: + v.handleUnsupportedMethod("collection", "createIndexes") + case mc.DropIndexMethod() != nil: + v.handleUnsupportedMethod("collection", "dropIndex") + case mc.DropIndexesMethod() != nil: + v.handleUnsupportedMethod("collection", "dropIndexes") + + // Planned M3 collection management - return PlannedOperationError for fallback + case mc.DropMethod() != nil: + v.handleUnsupportedMethod("collection", "drop") + case mc.RenameCollectionMethod() != nil: + v.handleUnsupportedMethod("collection", "renameCollection") + + // Planned M3 stats operations - return PlannedOperationError for fallback + case mc.StatsMethod() != nil: + v.handleUnsupportedMethod("collection", "stats") + case mc.StorageSizeMethod() != nil: + v.handleUnsupportedMethod("collection", "storageSize") + case mc.TotalIndexSizeMethod() != nil: + v.handleUnsupportedMethod("collection", "totalIndexSize") + case mc.TotalSizeMethod() != nil: + v.handleUnsupportedMethod("collection", "totalSize") + case mc.DataSizeMethod() != nil: + v.handleUnsupportedMethod("collection", "dataSize") + case mc.IsCappedMethod() != nil: + v.handleUnsupportedMethod("collection", "isCapped") + case mc.ValidateMethod() != nil: + v.handleUnsupportedMethod("collection", "validate") + case mc.LatencyStatsMethod() != nil: + v.handleUnsupportedMethod("collection", "latencyStats") + + // Generic method fallback - all methods going through genericMethod are unsupported + case mc.GenericMethod() != nil: + gmCtx, ok := mc.GenericMethod().(*mongodb.GenericMethodContext) + if !ok { + return + } + methodName := gmCtx.Identifier().GetText() + v.handleUnsupportedMethod(getMethodContext(), methodName) + + // Default: all other methods not explicitly handled + // These go to handleUnsupportedMethod which returns UnsupportedOperationError + // since they're not in the planned registry + default: + // Extract method name from the parse tree for error message + methodName := v.extractMethodName(mc) + if methodName != "" { + v.handleUnsupportedMethod(getMethodContext(), methodName) + } + } +} + +// extractMethodName extracts the method name from a MethodCallContext for error messages. +func (v *visitor) extractMethodName(mc *mongodb.MethodCallContext) string { + // Try to get method name from various method contexts + // The parser creates specific method contexts for known methods + // For unknown methods, they go through GenericMethod which is handled separately + text := mc.GetText() + // Extract method name before the opening parenthesis + if idx := strings.Index(text, "("); idx > 0 { + return text[:idx] + } + return text +} + +// handleUnsupportedMethod checks the method registry and returns appropriate errors. +// If method is in registry (planned for M2/M3) -> PlannedOperationError (fallback to mongosh) +// If method is NOT in registry -> UnsupportedOperationError (no fallback) +func (v *visitor) handleUnsupportedMethod(context, methodName string) { + if IsPlannedMethod(context, methodName) { + v.err = &PlannedOperationError{ + Operation: methodName + "()", + } + return + } + v.err = &UnsupportedOperationError{ + Operation: methodName + "()", + } +} diff --git a/method_registry.go b/method_registry.go index aedf844..2382da1 100644 --- a/method_registry.go +++ b/method_registry.go @@ -1,90 +1,5 @@ package gomongo -// methodStatus represents the support status of a MongoDB method. -type methodStatus int - -const ( - // statusPlanned means the method is planned for implementation (M2/M3). - // When encountered, the caller should fallback to mongosh. - statusPlanned methodStatus = iota -) - -// methodInfo contains metadata about a MongoDB method. -type methodInfo struct { - status methodStatus -} - -// methodRegistry contains only methods we plan to implement (M2, M3). -// If a method is NOT in this registry, it's unsupported (throw error, no fallback). -// If a method IS in this registry, it's planned (fallback to mongosh). -var methodRegistry = map[string]methodInfo{ - // ============================================================ - // MILESTONE 2: Write Operations (10 methods) - // ============================================================ - - // Insert Commands (2) - "collection:insertOne": {status: statusPlanned}, - "collection:insertMany": {status: statusPlanned}, - - // Update Commands (3) - "collection:updateOne": {status: statusPlanned}, - "collection:updateMany": {status: statusPlanned}, - "collection:replaceOne": {status: statusPlanned}, - - // Delete Commands (2) - "collection:deleteOne": {status: statusPlanned}, - "collection:deleteMany": {status: statusPlanned}, - - // Atomic Find-and-Modify Commands (3) - "collection:findOneAndUpdate": {status: statusPlanned}, - "collection:findOneAndReplace": {status: statusPlanned}, - "collection:findOneAndDelete": {status: statusPlanned}, - - // ============================================================ - // MILESTONE 3: Administrative Operations (22 methods) - // ============================================================ - - // Index Management (4) - "collection:createIndex": {status: statusPlanned}, - "collection:createIndexes": {status: statusPlanned}, - "collection:dropIndex": {status: statusPlanned}, - "collection:dropIndexes": {status: statusPlanned}, - - // Collection Management (4) - "database:createCollection": {status: statusPlanned}, - "collection:drop": {status: statusPlanned}, - "collection:renameCollection": {status: statusPlanned}, - "database:dropDatabase": {status: statusPlanned}, - - // Database Information (7) - "database:stats": {status: statusPlanned}, - "collection:stats": {status: statusPlanned}, - "database:serverStatus": {status: statusPlanned}, - "database:serverBuildInfo": {status: statusPlanned}, - "database:version": {status: statusPlanned}, - "database:hostInfo": {status: statusPlanned}, - "database:listCommands": {status: statusPlanned}, - - // Collection Information (7) - "collection:dataSize": {status: statusPlanned}, - "collection:storageSize": {status: statusPlanned}, - "collection:totalIndexSize": {status: statusPlanned}, - "collection:totalSize": {status: statusPlanned}, - "collection:isCapped": {status: statusPlanned}, - "collection:validate": {status: statusPlanned}, - "collection:latencyStats": {status: statusPlanned}, -} - -// IsPlannedMethod checks if a method is in the registry (planned for implementation). -// Returns true if the method should fallback to mongosh. -// Returns false if the method is unsupported (throw error). -func IsPlannedMethod(context, methodName string) bool { - key := context + ":" + methodName - _, ok := methodRegistry[key] - return ok -} - -// MethodRegistryStats returns statistics about the method registry. -func MethodRegistryStats() int { - return len(methodRegistry) -} +// This file is kept for backward compatibility. +// The method registry is now maintained in internal/translator/method_registry.go. +// Public access to registry functions is provided via errors.go. diff --git a/translator.go b/translator.go deleted file mode 100644 index 5f679cf..0000000 --- a/translator.go +++ /dev/null @@ -1,1468 +0,0 @@ -package gomongo - -import ( - "fmt" - "strconv" - "strings" - - "github.com/antlr4-go/antlr/v4" - "github.com/bytebase/parser/mongodb" - "go.mongodb.org/mongo-driver/v2/bson" -) - -type operationType int - -const ( - opUnknown operationType = iota - opFind - opFindOne - opAggregate - opShowDatabases - opShowCollections - opGetCollectionNames - opGetCollectionInfos - opGetIndexes - opCountDocuments - opEstimatedDocumentCount - opDistinct -) - -// mongoOperation represents a parsed MongoDB operation. -type mongoOperation struct { - opType operationType - collection string - filter bson.D - // Read operation options (find, findOne) - sort bson.D - limit *int64 - skip *int64 - projection bson.D - // Index scan bounds and query options - hint any // string (index name) or document (index spec) - max bson.D // upper bound for index scan - min bson.D // lower bound for index scan - maxTimeMS *int64 // max execution time in milliseconds - // Aggregation pipeline - pipeline bson.A - // distinct field name - distinctField string - // getCollectionInfos options - nameOnly *bool - authorizedCollections *bool -} - -// mongoShellVisitor extracts operations from a parse tree. -type mongoShellVisitor struct { - mongodb.BaseMongoShellParserVisitor - operation *mongoOperation - err error -} - -func newMongoShellVisitor() *mongoShellVisitor { - return &mongoShellVisitor{ - operation: &mongoOperation{opType: opUnknown}, - } -} - -func (v *mongoShellVisitor) Visit(tree antlr.ParseTree) any { - return tree.Accept(v) -} - -func (v *mongoShellVisitor) VisitProgram(ctx *mongodb.ProgramContext) any { - v.visitProgram(ctx) - return nil -} - -func (v *mongoShellVisitor) visitProgram(ctx mongodb.IProgramContext) { - for _, stmt := range ctx.AllStatement() { - v.visitStatement(stmt) - if v.err != nil { - return - } - } -} - -func (v *mongoShellVisitor) VisitStatement(ctx *mongodb.StatementContext) any { - v.visitStatement(ctx) - return nil -} - -func (v *mongoShellVisitor) visitStatement(ctx mongodb.IStatementContext) { - if ctx.DbStatement() != nil { - v.visitDbStatement(ctx.DbStatement()) - } else if ctx.ShellCommand() != nil { - v.visitShellCommand(ctx.ShellCommand()) - } -} - -func (v *mongoShellVisitor) visitDbStatement(ctx mongodb.IDbStatementContext) { - switch c := ctx.(type) { - case *mongodb.CollectionOperationContext: - v.visitCollectionOperation(c) - case *mongodb.GetCollectionNamesContext: - v.operation.opType = opGetCollectionNames - case *mongodb.GetCollectionInfosContext: - v.operation.opType = opGetCollectionInfos - v.extractGetCollectionInfosArgs(c) - } -} - -func (v *mongoShellVisitor) visitShellCommand(ctx mongodb.IShellCommandContext) { - switch ctx.(type) { - case *mongodb.ShowDatabasesContext: - v.operation.opType = opShowDatabases - case *mongodb.ShowCollectionsContext: - v.operation.opType = opShowCollections - default: - v.err = &UnsupportedOperationError{ - Operation: ctx.GetText(), - } - } -} - -func (v *mongoShellVisitor) VisitCollectionOperation(ctx *mongodb.CollectionOperationContext) any { - v.visitCollectionOperation(ctx) - return nil -} - -func (v *mongoShellVisitor) visitCollectionOperation(ctx *mongodb.CollectionOperationContext) { - v.operation.collection = v.extractCollectionName(ctx.CollectionAccess()) - - if ctx.MethodChain() != nil { - v.visitMethodChain(ctx.MethodChain()) - } -} - -func (v *mongoShellVisitor) VisitGetCollectionNames(_ *mongodb.GetCollectionNamesContext) any { - v.operation.opType = opGetCollectionNames - return nil -} - -func (v *mongoShellVisitor) VisitGetCollectionInfos(ctx *mongodb.GetCollectionInfosContext) any { - v.operation.opType = opGetCollectionInfos - v.extractGetCollectionInfosArgs(ctx) - return nil -} - -func (v *mongoShellVisitor) extractGetCollectionInfosArgs(ctx *mongodb.GetCollectionInfosContext) { - args := ctx.Arguments() - if args == nil { - return - } - - argsCtx, ok := args.(*mongodb.ArgumentsContext) - if !ok { - return - } - - allArgs := argsCtx.AllArgument() - if len(allArgs) == 0 { - return - } - - // First argument is the filter (optional) - firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) - if !ok { - return - } - - valueCtx := firstArg.Value() - if valueCtx == nil { - return - } - - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("getCollectionInfos() filter must be a document") - return - } - - filter, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid filter: %w", err) - return - } - v.operation.filter = filter - - // Second argument is the options (optional) - if len(allArgs) >= 2 { - secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) - if !ok { - return - } - - optionsValueCtx := secondArg.Value() - if optionsValueCtx == nil { - return - } - - optionsDocValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("getCollectionInfos() options must be a document") - return - } - - optionsDoc, err := convertDocument(optionsDocValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - - for _, opt := range optionsDoc { - switch opt.Key { - case "nameOnly": - if val, ok := opt.Value.(bool); ok { - v.operation.nameOnly = &val - } else { - v.err = fmt.Errorf("getCollectionInfos() nameOnly must be a boolean") - return - } - case "authorizedCollections": - if val, ok := opt.Value.(bool); ok { - v.operation.authorizedCollections = &val - } else { - v.err = fmt.Errorf("getCollectionInfos() authorizedCollections must be a boolean") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "getCollectionInfos()", - Option: opt.Key, - } - return - } - } - } - - if len(allArgs) > 2 { - v.err = fmt.Errorf("getCollectionInfos() takes at most 2 arguments") - return - } -} - -// extractCountDocumentsArgsFromMethod extracts arguments from CountDocumentsMethodContext. -func (v *mongoShellVisitor) extractCountDocumentsArgsFromMethod(ctx mongodb.ICountDocumentsMethodContext) { - method, ok := ctx.(*mongodb.CountDocumentsMethodContext) - if !ok { - return - } - v.extractArgumentsForCountDocuments(method.Arguments()) -} - -// extractArgumentsForCountDocuments extracts countDocuments arguments from IArgumentsContext. -func (v *mongoShellVisitor) extractArgumentsForCountDocuments(args mongodb.IArgumentsContext) { - if args == nil { - return - } - - argsCtx, ok := args.(*mongodb.ArgumentsContext) - if !ok { - return - } - - allArgs := argsCtx.AllArgument() - if len(allArgs) == 0 { - return - } - - // First argument is the filter (optional) - firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) - if !ok { - return - } - - valueCtx := firstArg.Value() - if valueCtx == nil { - return - } - - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("countDocuments() filter must be a document") - return - } - - filter, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid filter: %w", err) - return - } - v.operation.filter = filter - - // Second argument is the options (optional) - if len(allArgs) < 2 { - return - } - - secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) - if !ok { - return - } - - optionsValueCtx := secondArg.Value() - if optionsValueCtx == nil { - return - } - - optionsDocValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("countDocuments() options must be a document") - return - } - - optionsDoc, err := convertDocument(optionsDocValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - - // Extract supported options: hint, limit, skip, maxTimeMS - for _, elem := range optionsDoc { - switch elem.Key { - case "hint": - v.operation.hint = elem.Value - case "limit": - if val, ok := elem.Value.(int32); ok { - limit := int64(val) - v.operation.limit = &limit - } else if val, ok := elem.Value.(int64); ok { - v.operation.limit = &val - } - case "skip": - if val, ok := elem.Value.(int32); ok { - skip := int64(val) - v.operation.skip = &skip - } else if val, ok := elem.Value.(int64); ok { - v.operation.skip = &val - } - case "maxTimeMS": - if val, ok := elem.Value.(int32); ok { - ms := int64(val) - v.operation.maxTimeMS = &ms - } else if val, ok := elem.Value.(int64); ok { - v.operation.maxTimeMS = &val - } else { - v.err = fmt.Errorf("countDocuments() maxTimeMS must be a number") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "countDocuments()", - Option: elem.Key, - } - return - } - } -} - -// extractEstimatedDocumentCountArgs extracts arguments from EstimatedDocumentCountMethodContext. -func (v *mongoShellVisitor) extractEstimatedDocumentCountArgs(ctx mongodb.IEstimatedDocumentCountMethodContext) { - method, ok := ctx.(*mongodb.EstimatedDocumentCountMethodContext) - if !ok { - return - } - - // EstimatedDocumentCountMethodContext has Argument() (singular) that returns a single optional argument - arg := method.Argument() - if arg == nil { - return - } - - argCtx, ok := arg.(*mongodb.ArgumentContext) - if !ok { - return - } - - valueCtx := argCtx.Value() - if valueCtx == nil { - return - } - - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("estimatedDocumentCount() options must be a document") - return - } - - options, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - - for _, opt := range options { - switch opt.Key { - case "maxTimeMS": - if val, ok := opt.Value.(int32); ok { - ms := int64(val) - v.operation.maxTimeMS = &ms - } else if val, ok := opt.Value.(int64); ok { - v.operation.maxTimeMS = &val - } else { - v.err = fmt.Errorf("estimatedDocumentCount() maxTimeMS must be a number") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "estimatedDocumentCount()", - Option: opt.Key, - } - return - } - } -} - -// extractDistinctArgsFromMethod extracts arguments from DistinctMethodContext. -func (v *mongoShellVisitor) extractDistinctArgsFromMethod(ctx mongodb.IDistinctMethodContext) { - method, ok := ctx.(*mongodb.DistinctMethodContext) - if !ok { - return - } - v.extractArgumentsForDistinct(method.Arguments()) -} - -// extractArgumentsForDistinct extracts distinct arguments from IArgumentsContext. -func (v *mongoShellVisitor) extractArgumentsForDistinct(args mongodb.IArgumentsContext) { - if args == nil { - v.err = fmt.Errorf("distinct() requires a field name argument") - return - } - - argsCtx, ok := args.(*mongodb.ArgumentsContext) - if !ok { - v.err = fmt.Errorf("distinct() requires a field name argument") - return - } - - allArgs := argsCtx.AllArgument() - if len(allArgs) == 0 { - v.err = fmt.Errorf("distinct() requires a field name argument") - return - } - - // First argument is the field name (required) - firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) - if !ok { - v.err = fmt.Errorf("distinct() requires a field name argument") - return - } - - valueCtx := firstArg.Value() - if valueCtx == nil { - v.err = fmt.Errorf("distinct() requires a field name argument") - return - } - - literalValue, ok := valueCtx.(*mongodb.LiteralValueContext) - if !ok { - v.err = fmt.Errorf("distinct() field name must be a string") - return - } - - stringLiteral, ok := literalValue.Literal().(*mongodb.StringLiteralValueContext) - if !ok { - v.err = fmt.Errorf("distinct() field name must be a string") - return - } - - v.operation.distinctField = unquoteString(stringLiteral.StringLiteral().GetText()) - - // Second argument is the filter (optional) - if len(allArgs) < 2 { - return - } - - secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) - if !ok { - return - } - - filterValueCtx := secondArg.Value() - if filterValueCtx == nil { - return - } - - docValue, ok := filterValueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("distinct() filter must be a document") - return - } - - filter, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid filter: %w", err) - return - } - v.operation.filter = filter - - // Third argument: options (optional) - if len(allArgs) >= 3 { - thirdArg, ok := allArgs[2].(*mongodb.ArgumentContext) - if !ok { - return - } - - optionsValueCtx := thirdArg.Value() - if optionsValueCtx == nil { - return - } - - docValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("distinct() options must be a document") - return - } - - options, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - - for _, opt := range options { - switch opt.Key { - case "maxTimeMS": - if val, ok := opt.Value.(int32); ok { - ms := int64(val) - v.operation.maxTimeMS = &ms - } else if val, ok := opt.Value.(int64); ok { - v.operation.maxTimeMS = &val - } else { - v.err = fmt.Errorf("distinct() maxTimeMS must be a number") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "distinct()", - Option: opt.Key, - } - return - } - } - } - - if len(allArgs) > 3 { - v.err = fmt.Errorf("distinct() takes at most 3 arguments") - return - } -} - -// extractAggregationPipelineFromMethod extracts pipeline from AggregateMethodContext. -func (v *mongoShellVisitor) extractAggregationPipelineFromMethod(ctx mongodb.IAggregateMethodContext) { - method, ok := ctx.(*mongodb.AggregateMethodContext) - if !ok { - return - } - v.extractArgumentsForAggregate(method.Arguments()) -} - -// extractArgumentsForAggregate extracts aggregate pipeline from IArgumentsContext. -func (v *mongoShellVisitor) extractArgumentsForAggregate(args mongodb.IArgumentsContext) { - if args == nil { - // Empty pipeline: aggregate() - v.operation.pipeline = bson.A{} - return - } - - argsCtx, ok := args.(*mongodb.ArgumentsContext) - if !ok { - v.err = fmt.Errorf("aggregate() requires an array argument") - return - } - - allArgs := argsCtx.AllArgument() - if len(allArgs) == 0 { - v.operation.pipeline = bson.A{} - return - } - - // First argument should be the pipeline array - firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) - if !ok { - v.err = fmt.Errorf("aggregate() requires an array argument") - return - } - - valueCtx := firstArg.Value() - if valueCtx == nil { - v.err = fmt.Errorf("aggregate() requires an array argument") - return - } - - arrayValue, ok := valueCtx.(*mongodb.ArrayValueContext) - if !ok { - v.err = fmt.Errorf("aggregate() requires an array argument, got %T", valueCtx) - return - } - - pipeline, err := convertArray(arrayValue.Array()) - if err != nil { - v.err = fmt.Errorf("invalid aggregation pipeline: %w", err) - return - } - - v.operation.pipeline = pipeline - - // Second argument: options (optional) - if len(allArgs) >= 2 { - secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) - if !ok { - return - } - optionsValueCtx := secondArg.Value() - if optionsValueCtx == nil { - return - } - docValue, ok := optionsValueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("aggregate() options must be a document") - return - } - options, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - for _, opt := range options { - switch opt.Key { - case "hint": - v.operation.hint = opt.Value - case "maxTimeMS": - if val, ok := opt.Value.(int32); ok { - ms := int64(val) - v.operation.maxTimeMS = &ms - } else if val, ok := opt.Value.(int64); ok { - v.operation.maxTimeMS = &val - } else { - v.err = fmt.Errorf("aggregate() maxTimeMS must be a number") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "aggregate()", - Option: opt.Key, - } - return - } - } - } - - // More than 2 arguments is an error - if len(allArgs) > 2 { - v.err = fmt.Errorf("aggregate() takes at most 2 arguments") - return - } -} - -func (v *mongoShellVisitor) extractCollectionName(ctx mongodb.ICollectionAccessContext) string { - switch c := ctx.(type) { - case *mongodb.DotAccessContext: - return c.Identifier().GetText() - case *mongodb.BracketAccessContext: - return unquoteString(c.StringLiteral().GetText()) - case *mongodb.GetCollectionAccessContext: - return unquoteString(c.StringLiteral().GetText()) - } - return "" -} - -func (v *mongoShellVisitor) visitMethodChain(ctx mongodb.IMethodChainContext) { - mc, ok := ctx.(*mongodb.MethodChainContext) - if !ok { - return - } - for _, methodCall := range mc.AllMethodCall() { - v.visitMethodCall(methodCall) - if v.err != nil { - return - } - } -} - -func (v *mongoShellVisitor) extractFindArgs(ctx mongodb.IFindMethodContext) { - fm, ok := ctx.(*mongodb.FindMethodContext) - if !ok { - return - } - - args := fm.Arguments() - if args == nil { - return - } - - argsCtx, ok := args.(*mongodb.ArgumentsContext) - if !ok { - return - } - - allArgs := argsCtx.AllArgument() - if len(allArgs) == 0 { - return - } - - // First argument: filter - firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) - if !ok { - return - } - valueCtx := firstArg.Value() - if valueCtx != nil { - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("find() filter must be a document") - return - } - filter, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid filter: %w", err) - return - } - v.operation.filter = filter - } - - // Second argument: projection (optional) - if len(allArgs) >= 2 { - secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) - if !ok { - return - } - valueCtx := secondArg.Value() - if valueCtx != nil { - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("find() projection must be a document") - return - } - projection, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid projection: %w", err) - return - } - v.operation.projection = projection - } - } - - // Third argument: options (optional) - if len(allArgs) >= 3 { - thirdArg, ok := allArgs[2].(*mongodb.ArgumentContext) - if !ok { - return - } - valueCtx := thirdArg.Value() - if valueCtx != nil { - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("find() options must be a document") - return - } - options, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - // Validate and extract supported options - for _, opt := range options { - switch opt.Key { - case "hint": - v.operation.hint = opt.Value - case "max": - if doc, ok := opt.Value.(bson.D); ok { - v.operation.max = doc - } else { - v.err = fmt.Errorf("find() max must be a document") - return - } - case "min": - if doc, ok := opt.Value.(bson.D); ok { - v.operation.min = doc - } else { - v.err = fmt.Errorf("find() min must be a document") - return - } - case "maxTimeMS": - if val, ok := opt.Value.(int32); ok { - ms := int64(val) - v.operation.maxTimeMS = &ms - } else if val, ok := opt.Value.(int64); ok { - v.operation.maxTimeMS = &val - } else { - v.err = fmt.Errorf("find() maxTimeMS must be a number") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "find()", - Option: opt.Key, - } - return - } - } - } - } - - // More than 3 arguments is an error - if len(allArgs) > 3 { - v.err = fmt.Errorf("find() takes at most 3 arguments") - return - } -} - -func (v *mongoShellVisitor) extractFindOneArgs(ctx mongodb.IFindOneMethodContext) { - fm, ok := ctx.(*mongodb.FindOneMethodContext) - if !ok { - return - } - - args := fm.Arguments() - if args == nil { - return - } - - argsCtx, ok := args.(*mongodb.ArgumentsContext) - if !ok { - return - } - - allArgs := argsCtx.AllArgument() - if len(allArgs) == 0 { - return - } - - // First argument: filter - firstArg, ok := allArgs[0].(*mongodb.ArgumentContext) - if !ok { - return - } - valueCtx := firstArg.Value() - if valueCtx != nil { - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("findOne() filter must be a document") - return - } - filter, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid filter: %w", err) - return - } - v.operation.filter = filter - } - - // Second argument: projection (optional) - if len(allArgs) >= 2 { - secondArg, ok := allArgs[1].(*mongodb.ArgumentContext) - if !ok { - return - } - valueCtx := secondArg.Value() - if valueCtx != nil { - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("findOne() projection must be a document") - return - } - projection, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid projection: %w", err) - return - } - v.operation.projection = projection - } - } - - // Third argument: options (optional) - if len(allArgs) >= 3 { - thirdArg, ok := allArgs[2].(*mongodb.ArgumentContext) - if !ok { - return - } - valueCtx := thirdArg.Value() - if valueCtx != nil { - docValue, ok := valueCtx.(*mongodb.DocumentValueContext) - if !ok { - v.err = fmt.Errorf("findOne() options must be a document") - return - } - options, err := convertDocument(docValue.Document()) - if err != nil { - v.err = fmt.Errorf("invalid options: %w", err) - return - } - // Validate and extract supported options - for _, opt := range options { - switch opt.Key { - case "hint": - v.operation.hint = opt.Value - case "max": - if doc, ok := opt.Value.(bson.D); ok { - v.operation.max = doc - } else { - v.err = fmt.Errorf("findOne() max must be a document") - return - } - case "min": - if doc, ok := opt.Value.(bson.D); ok { - v.operation.min = doc - } else { - v.err = fmt.Errorf("findOne() min must be a document") - return - } - case "maxTimeMS": - if val, ok := opt.Value.(int32); ok { - ms := int64(val) - v.operation.maxTimeMS = &ms - } else if val, ok := opt.Value.(int64); ok { - v.operation.maxTimeMS = &val - } else { - v.err = fmt.Errorf("findOne() maxTimeMS must be a number") - return - } - default: - v.err = &UnsupportedOptionError{ - Method: "findOne()", - Option: opt.Key, - } - return - } - } - } - } - - // More than 3 arguments is an error - if len(allArgs) > 3 { - v.err = fmt.Errorf("findOne() takes at most 3 arguments") - return - } -} - -func (v *mongoShellVisitor) extractSort(ctx mongodb.ISortMethodContext) { - sm, ok := ctx.(*mongodb.SortMethodContext) - if !ok { - return - } - - doc := sm.Document() - if doc == nil { - v.err = fmt.Errorf("sort() requires a document argument") - return - } - - sort, err := convertDocument(doc) - if err != nil { - v.err = fmt.Errorf("invalid sort: %w", err) - return - } - v.operation.sort = sort -} - -func (v *mongoShellVisitor) extractLimit(ctx mongodb.ILimitMethodContext) { - lm, ok := ctx.(*mongodb.LimitMethodContext) - if !ok { - return - } - - numNode := lm.NUMBER() - if numNode == nil { - v.err = fmt.Errorf("limit() requires a number argument") - return - } - - limit, err := strconv.ParseInt(numNode.GetText(), 10, 64) - if err != nil { - v.err = fmt.Errorf("invalid limit: %w", err) - return - } - v.operation.limit = &limit -} - -func (v *mongoShellVisitor) extractSkip(ctx mongodb.ISkipMethodContext) { - sm, ok := ctx.(*mongodb.SkipMethodContext) - if !ok { - return - } - - numNode := sm.NUMBER() - if numNode == nil { - v.err = fmt.Errorf("skip() requires a number argument") - return - } - - skip, err := strconv.ParseInt(numNode.GetText(), 10, 64) - if err != nil { - v.err = fmt.Errorf("invalid skip: %w", err) - return - } - v.operation.skip = &skip -} - -func (v *mongoShellVisitor) extractProjection(ctx mongodb.IProjectionMethodContext) { - pm, ok := ctx.(*mongodb.ProjectionMethodContext) - if !ok { - return - } - - doc := pm.Document() - if doc == nil { - v.err = fmt.Errorf("projection() requires a document argument") - return - } - - projection, err := convertDocument(doc) - if err != nil { - v.err = fmt.Errorf("invalid projection: %w", err) - return - } - v.operation.projection = projection -} - -func (v *mongoShellVisitor) extractHint(ctx mongodb.IHintMethodContext) { - hm, ok := ctx.(*mongodb.HintMethodContext) - if !ok { - return - } - - arg := hm.Argument() - if arg == nil { - v.err = fmt.Errorf("hint() requires an argument") - return - } - - argCtx, ok := arg.(*mongodb.ArgumentContext) - if !ok { - return - } - - valueCtx := argCtx.Value() - if valueCtx == nil { - return - } - - // hint can be a string (index name) or document (index spec) - switch val := valueCtx.(type) { - case *mongodb.LiteralValueContext: - strLit, ok := val.Literal().(*mongodb.StringLiteralValueContext) - if !ok { - v.err = fmt.Errorf("hint() argument must be a string or document") - return - } - v.operation.hint = unquoteString(strLit.StringLiteral().GetText()) - case *mongodb.DocumentValueContext: - doc, err := convertDocument(val.Document()) - if err != nil { - v.err = fmt.Errorf("invalid hint: %w", err) - return - } - v.operation.hint = doc - default: - v.err = fmt.Errorf("hint() argument must be a string or document") - } -} - -func (v *mongoShellVisitor) extractMax(ctx mongodb.IMaxMethodContext) { - mm, ok := ctx.(*mongodb.MaxMethodContext) - if !ok { - return - } - - doc := mm.Document() - if doc == nil { - v.err = fmt.Errorf("max() requires a document argument") - return - } - - maxDoc, err := convertDocument(doc) - if err != nil { - v.err = fmt.Errorf("invalid max: %w", err) - return - } - v.operation.max = maxDoc -} - -func (v *mongoShellVisitor) extractMin(ctx mongodb.IMinMethodContext) { - mm, ok := ctx.(*mongodb.MinMethodContext) - if !ok { - return - } - - doc := mm.Document() - if doc == nil { - v.err = fmt.Errorf("min() requires a document argument") - return - } - - minDoc, err := convertDocument(doc) - if err != nil { - v.err = fmt.Errorf("invalid min: %w", err) - return - } - v.operation.min = minDoc -} - -func (v *mongoShellVisitor) visitMethodCall(ctx mongodb.IMethodCallContext) { - mc, ok := ctx.(*mongodb.MethodCallContext) - if !ok { - return - } - - // Determine method context for registry lookup - getMethodContext := func() string { - if v.operation.opType == opFind || v.operation.opType == opFindOne { - return "cursor" - } - return "collection" - } - - switch { - // Supported read operations - case mc.FindMethod() != nil: - v.operation.opType = opFind - v.extractFindArgs(mc.FindMethod()) - case mc.FindOneMethod() != nil: - v.operation.opType = opFindOne - v.extractFindOneArgs(mc.FindOneMethod()) - case mc.CountDocumentsMethod() != nil: - v.operation.opType = opCountDocuments - v.extractCountDocumentsArgsFromMethod(mc.CountDocumentsMethod()) - case mc.EstimatedDocumentCountMethod() != nil: - v.operation.opType = opEstimatedDocumentCount - v.extractEstimatedDocumentCountArgs(mc.EstimatedDocumentCountMethod()) - case mc.DistinctMethod() != nil: - v.operation.opType = opDistinct - v.extractDistinctArgsFromMethod(mc.DistinctMethod()) - case mc.AggregateMethod() != nil: - v.operation.opType = opAggregate - v.extractAggregationPipelineFromMethod(mc.AggregateMethod()) - case mc.GetIndexesMethod() != nil: - v.operation.opType = opGetIndexes - - // Supported cursor modifiers - case mc.SortMethod() != nil: - v.extractSort(mc.SortMethod()) - case mc.LimitMethod() != nil: - v.extractLimit(mc.LimitMethod()) - case mc.SkipMethod() != nil: - v.extractSkip(mc.SkipMethod()) - case mc.ProjectionMethod() != nil: - v.extractProjection(mc.ProjectionMethod()) - case mc.HintMethod() != nil: - v.extractHint(mc.HintMethod()) - case mc.MaxMethod() != nil: - v.extractMax(mc.MaxMethod()) - case mc.MinMethod() != nil: - v.extractMin(mc.MinMethod()) - - // Planned M2 write operations - return PlannedOperationError for fallback - case mc.InsertOneMethod() != nil: - v.handleUnsupportedMethod("collection", "insertOne") - case mc.InsertManyMethod() != nil: - v.handleUnsupportedMethod("collection", "insertMany") - case mc.UpdateOneMethod() != nil: - v.handleUnsupportedMethod("collection", "updateOne") - case mc.UpdateManyMethod() != nil: - v.handleUnsupportedMethod("collection", "updateMany") - case mc.DeleteOneMethod() != nil: - v.handleUnsupportedMethod("collection", "deleteOne") - case mc.DeleteManyMethod() != nil: - v.handleUnsupportedMethod("collection", "deleteMany") - case mc.ReplaceOneMethod() != nil: - v.handleUnsupportedMethod("collection", "replaceOne") - case mc.FindOneAndUpdateMethod() != nil: - v.handleUnsupportedMethod("collection", "findOneAndUpdate") - case mc.FindOneAndReplaceMethod() != nil: - v.handleUnsupportedMethod("collection", "findOneAndReplace") - case mc.FindOneAndDeleteMethod() != nil: - v.handleUnsupportedMethod("collection", "findOneAndDelete") - - // Planned M3 index operations - return PlannedOperationError for fallback - case mc.CreateIndexMethod() != nil: - v.handleUnsupportedMethod("collection", "createIndex") - case mc.CreateIndexesMethod() != nil: - v.handleUnsupportedMethod("collection", "createIndexes") - case mc.DropIndexMethod() != nil: - v.handleUnsupportedMethod("collection", "dropIndex") - case mc.DropIndexesMethod() != nil: - v.handleUnsupportedMethod("collection", "dropIndexes") - - // Planned M3 collection management - return PlannedOperationError for fallback - case mc.DropMethod() != nil: - v.handleUnsupportedMethod("collection", "drop") - case mc.RenameCollectionMethod() != nil: - v.handleUnsupportedMethod("collection", "renameCollection") - - // Planned M3 stats operations - return PlannedOperationError for fallback - case mc.StatsMethod() != nil: - v.handleUnsupportedMethod("collection", "stats") - case mc.StorageSizeMethod() != nil: - v.handleUnsupportedMethod("collection", "storageSize") - case mc.TotalIndexSizeMethod() != nil: - v.handleUnsupportedMethod("collection", "totalIndexSize") - case mc.TotalSizeMethod() != nil: - v.handleUnsupportedMethod("collection", "totalSize") - case mc.DataSizeMethod() != nil: - v.handleUnsupportedMethod("collection", "dataSize") - case mc.IsCappedMethod() != nil: - v.handleUnsupportedMethod("collection", "isCapped") - case mc.ValidateMethod() != nil: - v.handleUnsupportedMethod("collection", "validate") - case mc.LatencyStatsMethod() != nil: - v.handleUnsupportedMethod("collection", "latencyStats") - - // Generic method fallback - all methods going through genericMethod are unsupported - case mc.GenericMethod() != nil: - gmCtx, ok := mc.GenericMethod().(*mongodb.GenericMethodContext) - if !ok { - return - } - methodName := gmCtx.Identifier().GetText() - v.handleUnsupportedMethod(getMethodContext(), methodName) - - // Default: all other methods not explicitly handled - // These go to handleUnsupportedMethod which returns UnsupportedOperationError - // since they're not in the planned registry - default: - // Extract method name from the parse tree for error message - methodName := v.extractMethodName(mc) - if methodName != "" { - v.handleUnsupportedMethod(getMethodContext(), methodName) - } - } -} - -// extractMethodName extracts the method name from a MethodCallContext for error messages. -func (v *mongoShellVisitor) extractMethodName(mc *mongodb.MethodCallContext) string { - // Try to get method name from various method contexts - // The parser creates specific method contexts for known methods - // For unknown methods, they go through GenericMethod which is handled separately - text := mc.GetText() - // Extract method name before the opening parenthesis - if idx := strings.Index(text, "("); idx > 0 { - return text[:idx] - } - return text -} - -// handleUnsupportedMethod checks the method registry and returns appropriate errors. -// If method is in registry (planned for M2/M3) -> PlannedOperationError (fallback to mongosh) -// If method is NOT in registry -> UnsupportedOperationError (no fallback) -func (v *mongoShellVisitor) handleUnsupportedMethod(context, methodName string) { - if IsPlannedMethod(context, methodName) { - v.err = &PlannedOperationError{ - Operation: methodName + "()", - } - return - } - v.err = &UnsupportedOperationError{ - Operation: methodName + "()", - } -} - -// unquoteString removes quotes from a string literal. -func unquoteString(s string) string { - if len(s) >= 2 { - if (s[0] == '"' && s[len(s)-1] == '"') || (s[0] == '\'' && s[len(s)-1] == '\'') { - return s[1 : len(s)-1] - } - } - return s -} - -// convertValue converts a parsed value context to a Go value for BSON. -func convertValue(ctx mongodb.IValueContext) (any, error) { - switch v := ctx.(type) { - case *mongodb.DocumentValueContext: - return convertDocument(v.Document()) - case *mongodb.ArrayValueContext: - return convertArray(v.Array()) - case *mongodb.LiteralValueContext: - return convertLiteral(v.Literal()) - case *mongodb.HelperValueContext: - return convertHelperFunction(v.HelperFunction()) - case *mongodb.RegexLiteralValueContext: - return convertRegexLiteral(v.REGEX_LITERAL().GetText()) - case *mongodb.RegexpConstructorValueContext: - return convertRegExpConstructor(v.RegExpConstructor()) - default: - return nil, fmt.Errorf("unsupported value type: %T", ctx) - } -} - -// convertDocument converts a document context to bson.D. -func convertDocument(ctx mongodb.IDocumentContext) (bson.D, error) { - doc, ok := ctx.(*mongodb.DocumentContext) - if !ok { - return nil, fmt.Errorf("invalid document context") - } - - result := bson.D{} - for _, pair := range doc.AllPair() { - key, value, err := convertPair(pair) - if err != nil { - return nil, err - } - result = append(result, bson.E{Key: key, Value: value}) - } - return result, nil -} - -// convertPair converts a pair context to key-value. -func convertPair(ctx mongodb.IPairContext) (string, any, error) { - pair, ok := ctx.(*mongodb.PairContext) - if !ok { - return "", nil, fmt.Errorf("invalid pair context") - } - - key := extractKey(pair.Key()) - value, err := convertValue(pair.Value()) - if err != nil { - return "", nil, fmt.Errorf("error converting value for key %q: %w", key, err) - } - return key, value, nil -} - -// extractKey extracts the key string from a key context. -func extractKey(ctx mongodb.IKeyContext) string { - switch k := ctx.(type) { - case *mongodb.UnquotedKeyContext: - return k.Identifier().GetText() - case *mongodb.QuotedKeyContext: - return unquoteString(k.StringLiteral().GetText()) - default: - return "" - } -} - -// convertArray converts an array context to bson.A. -func convertArray(ctx mongodb.IArrayContext) (bson.A, error) { - arr, ok := ctx.(*mongodb.ArrayContext) - if !ok { - return nil, fmt.Errorf("invalid array context") - } - - result := bson.A{} - for _, val := range arr.AllValue() { - v, err := convertValue(val) - if err != nil { - return nil, err - } - result = append(result, v) - } - return result, nil -} - -// convertLiteral converts a literal context to a Go value. -func convertLiteral(ctx mongodb.ILiteralContext) (any, error) { - switch l := ctx.(type) { - case *mongodb.NumberLiteralContext: - return parseNumber(l.NUMBER().GetText()) - case *mongodb.StringLiteralValueContext: - return unquoteString(l.StringLiteral().GetText()), nil - case *mongodb.TrueLiteralContext: - return true, nil - case *mongodb.FalseLiteralContext: - return false, nil - case *mongodb.NullLiteralContext: - return nil, nil - default: - return nil, fmt.Errorf("unsupported literal type: %T", ctx) - } -} - -// parseNumber parses a number string to int32, int64, or float64. -func parseNumber(s string) (any, error) { - if strings.Contains(s, ".") || strings.Contains(s, "e") || strings.Contains(s, "E") { - f, err := strconv.ParseFloat(s, 64) - if err != nil { - return nil, fmt.Errorf("invalid number: %s", s) - } - return f, nil - } - - i, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return nil, fmt.Errorf("invalid number: %s", s) - } - - if i >= -2147483648 && i <= 2147483647 { - return int32(i), nil - } - return i, nil -} - -// convertHelperFunction converts a helper function to a BSON value. -func convertHelperFunction(ctx mongodb.IHelperFunctionContext) (any, error) { - helper, ok := ctx.(*mongodb.HelperFunctionContext) - if !ok { - return nil, fmt.Errorf("invalid helper function context") - } - - if helper.ObjectIdHelper() != nil { - return convertObjectIdHelper(helper.ObjectIdHelper()) - } - if helper.IsoDateHelper() != nil { - return convertIsoDateHelper(helper.IsoDateHelper()) - } - if helper.DateHelper() != nil { - return convertDateHelper(helper.DateHelper()) - } - if helper.UuidHelper() != nil { - return convertUuidHelper(helper.UuidHelper()) - } - if helper.LongHelper() != nil { - return convertLongHelper(helper.LongHelper()) - } - if helper.Int32Helper() != nil { - return convertInt32Helper(helper.Int32Helper()) - } - if helper.DoubleHelper() != nil { - return convertDoubleHelper(helper.DoubleHelper()) - } - if helper.Decimal128Helper() != nil { - return convertDecimal128Helper(helper.Decimal128Helper()) - } - if helper.TimestampHelper() != nil { - return convertTimestampHelper(helper.TimestampHelper()) - } - - return nil, fmt.Errorf("unsupported helper function") -} - -// convertRegexLiteral converts a regex literal like /pattern/flags to bson.Regex. -func convertRegexLiteral(text string) (bson.Regex, error) { - if len(text) < 2 || text[0] != '/' { - return bson.Regex{}, fmt.Errorf("invalid regex literal: %s", text) - } - - lastSlash := strings.LastIndex(text, "/") - if lastSlash <= 0 { - return bson.Regex{}, fmt.Errorf("invalid regex literal: %s", text) - } - - pattern := text[1:lastSlash] - options := "" - if lastSlash < len(text)-1 { - options = text[lastSlash+1:] - } - - return bson.Regex{Pattern: pattern, Options: options}, nil -} - -// convertRegExpConstructor converts RegExp("pattern", "flags") to bson.Regex. -func convertRegExpConstructor(ctx mongodb.IRegExpConstructorContext) (bson.Regex, error) { - constructor, ok := ctx.(*mongodb.RegExpConstructorContext) - if !ok { - return bson.Regex{}, fmt.Errorf("invalid RegExp constructor context") - } - - strings := constructor.AllStringLiteral() - if len(strings) == 0 { - return bson.Regex{}, fmt.Errorf("RegExp requires at least a pattern argument") - } - - pattern := unquoteString(strings[0].GetText()) - options := "" - if len(strings) > 1 { - options = unquoteString(strings[1].GetText()) - } - - return bson.Regex{Pattern: pattern, Options: options}, nil -}