From 52310b9af47ab19719b02ee06a84c14936114f69 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 26 Nov 2019 13:27:54 -0800 Subject: [PATCH 1/4] models impl for tag stealing --- pkg/repositories/gormimpl/tag.go | 73 +++++++++++++++++++++++++-- pkg/repositories/gormimpl/tag_test.go | 6 +++ pkg/repositories/models/tag.go | 10 ++-- 3 files changed, 81 insertions(+), 8 deletions(-) diff --git a/pkg/repositories/gormimpl/tag.go b/pkg/repositories/gormimpl/tag.go index 2dd9c920..548b4d02 100644 --- a/pkg/repositories/gormimpl/tag.go +++ b/pkg/repositories/gormimpl/tag.go @@ -4,10 +4,12 @@ import ( "context" "github.com/jinzhu/gorm" + "github.com/lyft/datacatalog/pkg/common" "github.com/lyft/datacatalog/pkg/repositories/errors" "github.com/lyft/datacatalog/pkg/repositories/interfaces" "github.com/lyft/datacatalog/pkg/repositories/models" idl_datacatalog "github.com/lyft/datacatalog/protos/gen" + "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/promutils" ) @@ -25,14 +27,79 @@ func NewTagRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope pro } } +// A tag is associated with a single artifact for each partition combination +// When creating a tag, we remove the tag from any artifacts of the same partition +// Then add the tag to the new artifact func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { timer := h.repoMetrics.CreateDuration.Start(ctx) defer timer.Stop() - db := h.db.Create(&tag) + tx := h.db.Begin() - if db.Error != nil { - return h.errorTransformer.ToDataCatalogError(db.Error) + var artifactToTag models.Artifact + tx = tx.Preload("Partitions").Find(&artifactToTag, models.Artifact{ + ArtifactKey: models.ArtifactKey{ArtifactID: tag.ArtifactID}, + }) + + // List artifacts with the same partitions and tag + filters := make([]models.ModelValueFilter, 0, len(artifactToTag.Partitions)*2+1) + for _, partition := range artifactToTag.Partitions { + filters = append(filters, NewGormValueFilter(common.Partition, common.Equal, "key", partition.Key)) + filters = append(filters, NewGormValueFilter(common.Partition, common.Equal, "value", partition.Value)) + } + + filters = append(filters, NewGormValueFilter(common.Artifact, common.Equal, "tag_name", tag.TagName)) + + listTaggedArtifacts := models.ListModelsInput{ + JoinEntityToConditionMap: map[common.Entity]models.ModelJoinCondition{ + common.Tag: NewGormJoinCondition(common.Artifact, common.Tag), + common.Partition: NewGormJoinCondition(common.Artifact, common.Partition), + }, + Filters: filters, + } + + tx, err := applyListModelsInput(tx, common.Artifact, listTaggedArtifacts) + if err != nil { + tx.Rollback() + return err + } + + var artifacts []models.Artifact + tx = tx.Find(&artifacts) + if tx.Error != nil { + logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, tx.Error) + tx.Rollback() + return h.errorTransformer.ToDataCatalogError(tx.Error) + } + + if len(artifacts) != 0 { + // Soft-delete the existing tags on the artifacts that are tagged by this tag in the partition + oldTags := make([]models.Tag, 0, len(artifacts)) + for _, artifact := range artifacts { + oldTags = append(oldTags, models.Tag{ + TagKey: models.TagKey{TagName: tag.TagName}, + ArtifactID: artifact.ArtifactID, + }) + } + tx = tx.Delete(&models.Tag{}, oldTags) + } + + // Check if the artifact was ever previously tagged with this tag, if so undelete the record + var previouslyTagged *models.Artifact + tx.Unscoped().Find(previouslyTagged, tag) + if previouslyTagged != nil { + previouslyTagged.DeletedAt = nil + tx = tx.Update(previouslyTagged) + } else { + // Tag the new artifact + tx = tx.Create(&tag) + } + + tx = tx.Commit() + if tx.Error != nil { + logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, tx.Error) + tx.Rollback() + return h.errorTransformer.ToDataCatalogError(tx.Error) } return nil } diff --git a/pkg/repositories/gormimpl/tag_test.go b/pkg/repositories/gormimpl/tag_test.go index 41f172e6..fb81352a 100644 --- a/pkg/repositories/gormimpl/tag_test.go +++ b/pkg/repositories/gormimpl/tag_test.go @@ -50,6 +50,12 @@ func TestCreateTag(t *testing.T) { GlobalMock.Logging = true // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(getTestArtifact())) + + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBArtifactResponse(getTestArtifact())) + GlobalMock.NewMock().WithQuery( `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback( func(s string, values []driver.NamedValue) { diff --git a/pkg/repositories/models/tag.go b/pkg/repositories/models/tag.go index 057f78c3..a0ce3f2f 100644 --- a/pkg/repositories/models/tag.go +++ b/pkg/repositories/models/tag.go @@ -1,17 +1,17 @@ package models type TagKey struct { - DatasetProject string `gorm:"primary_key"` - DatasetName string `gorm:"primary_key"` - DatasetDomain string `gorm:"primary_key"` - DatasetVersion string `gorm:"primary_key"` + DatasetProject string + DatasetName string + DatasetDomain string + DatasetVersion string TagName string `gorm:"primary_key"` } type Tag struct { BaseModel TagKey - ArtifactID string + ArtifactID string `gorm:"primary_key"` DatasetUUID string `gorm:"type:uuid;index:tags_dataset_uuid_idx"` Artifact Artifact `gorm:"association_foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID;foreignkey:DatasetProject,DatasetName,DatasetDomain,DatasetVersion,ArtifactID"` } From 5b581657306c84cc24d02d363d28c4a2d6068bd1 Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 10 Dec 2019 10:23:10 -0800 Subject: [PATCH 2/4] tag stealing in tx --- pkg/repositories/gormimpl/tag.go | 50 ++++++++++++++------------- pkg/repositories/gormimpl/tag_test.go | 9 +++-- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/pkg/repositories/gormimpl/tag.go b/pkg/repositories/gormimpl/tag.go index 548b4d02..bc5459af 100644 --- a/pkg/repositories/gormimpl/tag.go +++ b/pkg/repositories/gormimpl/tag.go @@ -50,7 +50,7 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { filters = append(filters, NewGormValueFilter(common.Artifact, common.Equal, "tag_name", tag.TagName)) - listTaggedArtifacts := models.ListModelsInput{ + listTaggedInput := models.ListModelsInput{ JoinEntityToConditionMap: map[common.Entity]models.ModelJoinCondition{ common.Tag: NewGormJoinCondition(common.Artifact, common.Tag), common.Partition: NewGormJoinCondition(common.Artifact, common.Partition), @@ -58,7 +58,7 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { Filters: filters, } - tx, err := applyListModelsInput(tx, common.Artifact, listTaggedArtifacts) + tx, err := applyListModelsInput(tx, common.Artifact, listTaggedInput) if err != nil { tx.Rollback() return err @@ -72,28 +72,30 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { return h.errorTransformer.ToDataCatalogError(tx.Error) } - if len(artifacts) != 0 { - // Soft-delete the existing tags on the artifacts that are tagged by this tag in the partition - oldTags := make([]models.Tag, 0, len(artifacts)) - for _, artifact := range artifacts { - oldTags = append(oldTags, models.Tag{ - TagKey: models.TagKey{TagName: tag.TagName}, - ArtifactID: artifact.ArtifactID, - }) - } - tx = tx.Delete(&models.Tag{}, oldTags) - } - - // Check if the artifact was ever previously tagged with this tag, if so undelete the record - var previouslyTagged *models.Artifact - tx.Unscoped().Find(previouslyTagged, tag) - if previouslyTagged != nil { - previouslyTagged.DeletedAt = nil - tx = tx.Update(previouslyTagged) - } else { - // Tag the new artifact - tx = tx.Create(&tag) - } + // if len(artifacts) != 0 { + // // Soft-delete the existing tags on the artifacts that are tagged by this tag in the partition + // for _, artifact := range artifacts { + // oldTag := models.Tag{ + // TagKey: models.TagKey{TagName: tag.TagName}, + // ArtifactID: artifact.ArtifactID, + // DatasetUUID: artifact.DatasetUUID, + // } + // tx = tx.Where(oldTag).Delete(&models.Tag{}) + // } + // } + + // If the artifact was ever previously tagged with this tag, we need to + // undelete the record because we cannot tag the artifact again since + // the primary keys are the same. + // var previouslyTagged *models.Artifact + // tx = tx.Unscoped().Find(previouslyTagged, tag) // unscope will ignore deletedAt + // if previouslyTagged != nil { + // previouslyTagged.DeletedAt = nil + // tx = tx.Update(previouslyTagged) + // } else { + // // Tag the new artifact + // tx = tx.Create(&tag) + // } tx = tx.Commit() if tx.Error != nil { diff --git a/pkg/repositories/gormimpl/tag_test.go b/pkg/repositories/gormimpl/tag_test.go index fb81352a..08adf569 100644 --- a/pkg/repositories/gormimpl/tag_test.go +++ b/pkg/repositories/gormimpl/tag_test.go @@ -49,12 +49,17 @@ func TestCreateTag(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.Logging = true + oldArtifact := getTestArtifact() + // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery( - `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(getTestArtifact())) + `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(oldArtifact)) + + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBPartitionResponse(oldArtifact)) GlobalMock.NewMock().WithQuery( - `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBArtifactResponse(getTestArtifact())) + `SELECT "artifacts".* FROM "artifacts" JOIN tags ON artifacts.artifact_id = tags.artifact_id JOIN partitions ON artifacts.artifact_id = partitions.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions.key = region) AND (partitions.value = SEA) AND (artifacts.tag_name = test-tagname)) LIMIT 0 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact)) GlobalMock.NewMock().WithQuery( `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback( From bcd89910d82c1a986fce11ed617d8398339b97db Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Thu, 13 Feb 2020 15:12:35 -0800 Subject: [PATCH 3/4] Add tag stealing and tests --- pkg/repositories/gormimpl/artifact_test.go | 2 +- pkg/repositories/gormimpl/tag.go | 116 +++++++++++++-------- pkg/repositories/gormimpl/tag_test.go | 52 +++++++-- 3 files changed, 122 insertions(+), 48 deletions(-) diff --git a/pkg/repositories/gormimpl/artifact_test.go b/pkg/repositories/gormimpl/artifact_test.go index e4155d57..f77ba0e9 100644 --- a/pkg/repositories/gormimpl/artifact_test.go +++ b/pkg/repositories/gormimpl/artifact_test.go @@ -181,7 +181,7 @@ func TestGetArtifact(t *testing.T) { GlobalMock.NewMock().WithQuery( `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123))) ORDER BY partitions.created_at ASC,"partitions"."dataset_uuid" ASC`).WithReply(expectedPartitionResponse) GlobalMock.NewMock().WithQuery( - `SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND ((("artifact_id","dataset_uuid") IN ((123,test-uuid)))) ORDER BY "tags"."dataset_project" ASC`).WithReply(expectedTagResponse) + `SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND ((("artifact_id","dataset_uuid") IN ((123,test-uuid)))) ORDER BY "tags"."tag_name" ASC`).WithReply(expectedTagResponse) getInput := models.ArtifactKey{ DatasetProject: artifact.DatasetProject, DatasetDomain: artifact.DatasetDomain, diff --git a/pkg/repositories/gormimpl/tag.go b/pkg/repositories/gormimpl/tag.go index bc5459af..b77fc3ca 100644 --- a/pkg/repositories/gormimpl/tag.go +++ b/pkg/repositories/gormimpl/tag.go @@ -34,72 +34,106 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { timer := h.repoMetrics.CreateDuration.Start(ctx) defer timer.Stop() + // There are several steps that need to be done in a transaction in order for tag stealing to occur tx := h.db.Begin() + defer func() { + if r := recover(); r != nil { + tx.Rollback() + } + }() + // 1. Find the set of partitions this artifact belongs to var artifactToTag models.Artifact - tx = tx.Preload("Partitions").Find(&artifactToTag, models.Artifact{ + tx.Preload("Partitions").Find(&artifactToTag, models.Artifact{ ArtifactKey: models.ArtifactKey{ArtifactID: tag.ArtifactID}, }) - // List artifacts with the same partitions and tag - filters := make([]models.ModelValueFilter, 0, len(artifactToTag.Partitions)*2+1) + // 2. List artifacts in the partitions that are currently tagged + modelFilters := make([]models.ModelFilter, 0, len(artifactToTag.Partitions)+2) for _, partition := range artifactToTag.Partitions { - filters = append(filters, NewGormValueFilter(common.Partition, common.Equal, "key", partition.Key)) - filters = append(filters, NewGormValueFilter(common.Partition, common.Equal, "value", partition.Value)) + modelFilters = append(modelFilters, models.ModelFilter{ + Entity: common.Partition, + ValueFilters: []models.ModelValueFilter{ + NewGormValueFilter(common.Equal, "key", partition.Key), + NewGormValueFilter(common.Equal, "value", partition.Value), + }, + JoinCondition: NewGormJoinCondition(common.Artifact, common.Partition), + }) } - filters = append(filters, NewGormValueFilter(common.Artifact, common.Equal, "tag_name", tag.TagName)) + modelFilters = append(modelFilters, models.ModelFilter{ + Entity: common.Tag, + ValueFilters: []models.ModelValueFilter{ + NewGormValueFilter(common.Equal, "tag_name", tag.TagName), + NewGormValueFilter(common.Equal, "deleted_at", gorm.Expr("NULL")), // AC: this may not work, may have to specially handle nil + }, + JoinCondition: NewGormJoinCondition(common.Artifact, common.Tag), + }) listTaggedInput := models.ListModelsInput{ - JoinEntityToConditionMap: map[common.Entity]models.ModelJoinCondition{ - common.Tag: NewGormJoinCondition(common.Artifact, common.Tag), - common.Partition: NewGormJoinCondition(common.Artifact, common.Partition), - }, - Filters: filters, + ModelFilters: modelFilters, + Limit: 100, } - tx, err := applyListModelsInput(tx, common.Artifact, listTaggedInput) + listArtifactsScope, err := applyListModelsInput(tx, common.Artifact, listTaggedInput) if err != nil { + logger.Errorf(ctx, "Unable to construct artiact list, rolling back, tag: [%v], err [%v]", tag, tx.Error) tx.Rollback() - return err + return h.errorTransformer.ToDataCatalogError(err) + } var artifacts []models.Artifact - tx = tx.Find(&artifacts) - if tx.Error != nil { - logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, tx.Error) + if listArtifactsScope.Find(&artifacts).Error != nil { + logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, listArtifactsScope.Error) tx.Rollback() - return h.errorTransformer.ToDataCatalogError(tx.Error) + return h.errorTransformer.ToDataCatalogError(listArtifactsScope.Error) } - // if len(artifacts) != 0 { - // // Soft-delete the existing tags on the artifacts that are tagged by this tag in the partition - // for _, artifact := range artifacts { - // oldTag := models.Tag{ - // TagKey: models.TagKey{TagName: tag.TagName}, - // ArtifactID: artifact.ArtifactID, - // DatasetUUID: artifact.DatasetUUID, - // } - // tx = tx.Where(oldTag).Delete(&models.Tag{}) - // } - // } - - // If the artifact was ever previously tagged with this tag, we need to - // undelete the record because we cannot tag the artifact again since + // 3. Remove the tags from the currently tagged artifacts + if len(artifacts) != 0 { + // Soft-delete the existing tags on the artifacts that are currently tagged + for _, artifact := range artifacts { + + // if the artifact to tag is already tagged, no need to remove it + if artifactToTag.ArtifactID != artifact.ArtifactID { + oldTag := models.Tag{ + TagKey: models.TagKey{TagName: tag.TagName}, + ArtifactID: artifact.ArtifactID, + DatasetUUID: artifact.DatasetUUID, + } + deleteScope := tx.NewScope(&models.Tag{}).DB().Delete(&models.Tag{}, oldTag) + if deleteScope.Error != nil { + logger.Errorf(ctx, "Unable to delete previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, deleteScope.Error) + tx.Rollback() + return h.errorTransformer.ToDataCatalogError(deleteScope.Error) + } + } + } + } + + // 4. If the artifact was ever previously tagged with this tag, we need to + // un-delete the record because we cannot tag the artifact again since // the primary keys are the same. - // var previouslyTagged *models.Artifact - // tx = tx.Unscoped().Find(previouslyTagged, tag) // unscope will ignore deletedAt - // if previouslyTagged != nil { - // previouslyTagged.DeletedAt = nil - // tx = tx.Update(previouslyTagged) - // } else { - // // Tag the new artifact - // tx = tx.Create(&tag) - // } + undeleteScope := tx.Unscoped().Model(&tag).Update("deleted_at", gorm.Expr("NULL")) // unscope will ignore deletedAt + if undeleteScope.Error != nil { + logger.Errorf(ctx, "Unable to undelete tag tag, rolling back, tag: [%v], err [%v]", tag, tx.Error) + tx.Rollback() + return h.errorTransformer.ToDataCatalogError(tx.Error) + } + + // 5. Tag the new artifact + if undeleteScope.RowsAffected == 0 { + if err := tx.Create(&tag).Error; err != nil { + logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, err) + tx.Rollback() + return h.errorTransformer.ToDataCatalogError(err) + } + } tx = tx.Commit() if tx.Error != nil { - logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, tx.Error) + logger.Errorf(ctx, "Unable to commit transaction, rolling back, tag: [%v], err [%v]", tag, tx.Error) tx.Rollback() return h.errorTransformer.ToDataCatalogError(tx.Error) } diff --git a/pkg/repositories/gormimpl/tag_test.go b/pkg/repositories/gormimpl/tag_test.go index 08adf569..e0cd9169 100644 --- a/pkg/repositories/gormimpl/tag_test.go +++ b/pkg/repositories/gormimpl/tag_test.go @@ -44,22 +44,58 @@ func getTestTag() models.Tag { } } -func TestCreateTag(t *testing.T) { +func TestCreateTagNew(t *testing.T) { + tagCreated := false + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + newArtifact := getTestArtifact() + + // Only match on queries that append expected filters + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(newArtifact)) + + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBPartitionResponse(newArtifact)) + + GlobalMock.NewMock().WithQuery( + `SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at = NULL)) LIMIT 100 OFFSET 0`).WithReply([]map[string]interface{}{}) + + GlobalMock.NewMock().WithQuery( + `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback( + func(s string, values []driver.NamedValue) { + tagCreated = true + }, + ) + + newTag := getTestTag() + newTag.ArtifactID = newArtifact.ArtifactID + + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) + err := tagRepo.Create(context.Background(), newTag) + + assert.NoError(t, err) + assert.True(t, tagCreated) +} + +func TestStealOldTag(t *testing.T) { tagCreated := false GlobalMock := mocket.Catcher.Reset() GlobalMock.Logging = true oldArtifact := getTestArtifact() + newArtifact := getTestArtifact() + newArtifact.ArtifactID = "111" // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery( - `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 123))`).WithReply(getDBArtifactResponse(oldArtifact)) + `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND (("artifacts"."artifact_id" = 111))`).WithReply(getDBArtifactResponse(newArtifact)) GlobalMock.NewMock().WithQuery( - `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBPartitionResponse(oldArtifact)) + `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (111)))`).WithReply(getDBPartitionResponse(newArtifact)) GlobalMock.NewMock().WithQuery( - `SELECT "artifacts".* FROM "artifacts" JOIN tags ON artifacts.artifact_id = tags.artifact_id JOIN partitions ON artifacts.artifact_id = partitions.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions.key = region) AND (partitions.value = SEA) AND (artifacts.tag_name = test-tagname)) LIMIT 0 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact)) + `SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at = NULL)) LIMIT 100 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact)) GlobalMock.NewMock().WithQuery( `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback( @@ -68,8 +104,12 @@ func TestCreateTag(t *testing.T) { }, ) + newTag := getTestTag() + newTag.ArtifactID = newArtifact.ArtifactID + tagRepo := NewTagRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope()) - err := tagRepo.Create(context.Background(), getTestTag()) + err := tagRepo.Create(context.Background(), newTag) + assert.NoError(t, err) assert.True(t, tagCreated) } @@ -82,7 +122,7 @@ func TestGetTag(t *testing.T) { // Only match on queries that append expected filters GlobalMock.NewMock().WithQuery( - `SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag)) ORDER BY tags.created_at DESC,"tags"."dataset_project" ASC LIMIT 1`).WithReply(getDBTagResponse(artifact)) + `SELECT * FROM "tags" WHERE "tags"."deleted_at" IS NULL AND (("tags"."dataset_project" = testProject) AND ("tags"."dataset_name" = testName) AND ("tags"."dataset_domain" = testDomain) AND ("tags"."dataset_version" = testVersion) AND ("tags"."tag_name" = test-tag)) ORDER BY tags.created_at DESC,"tags"."tag_name" ASC LIMIT 1`).WithReply(getDBTagResponse(artifact)) GlobalMock.NewMock().WithQuery( `SELECT * FROM "artifacts" WHERE "artifacts"."deleted_at" IS NULL AND ((("dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id") IN ((testProject,testName,testDomain,testVersion,123))))`).WithReply(getDBArtifactResponse(artifact)) GlobalMock.NewMock().WithQuery( From b4162238f7dbfb48397fa12758a206b3a1484b8f Mon Sep 17 00:00:00 2001 From: Andrew Chan Date: Tue, 18 Feb 2020 18:24:31 -0800 Subject: [PATCH 4/4] Correct previously tagged artifacts --- pkg/common/filters.go | 1 + pkg/repositories/factory.go | 1 + pkg/repositories/gormimpl/filter.go | 16 +++++++++++++++- pkg/repositories/gormimpl/list.go | 7 ++++++- pkg/repositories/gormimpl/tag.go | 10 +++++----- pkg/repositories/gormimpl/tag_test.go | 4 ++-- 6 files changed, 30 insertions(+), 9 deletions(-) diff --git a/pkg/common/filters.go b/pkg/common/filters.go index 58ca7916..bfb5f950 100644 --- a/pkg/common/filters.go +++ b/pkg/common/filters.go @@ -21,5 +21,6 @@ type ComparisonOperator int const ( Equal ComparisonOperator = iota + IsNull // Add more operators as needed, ie., gte, lte ) diff --git a/pkg/repositories/factory.go b/pkg/repositories/factory.go index 535b091d..b5e062d1 100644 --- a/pkg/repositories/factory.go +++ b/pkg/repositories/factory.go @@ -35,6 +35,7 @@ func GetRepository(repoType RepoConfig, dbConfig config.DbConfig, scope promutil if err != nil { panic(err) } + db.LogMode(true) return NewPostgresRepo( db, errors.NewPostgresErrorTransformer(), diff --git a/pkg/repositories/gormimpl/filter.go b/pkg/repositories/gormimpl/filter.go index f636e3d1..c5309606 100644 --- a/pkg/repositories/gormimpl/filter.go +++ b/pkg/repositories/gormimpl/filter.go @@ -10,7 +10,8 @@ import ( // String formats for various GORM expression queries const ( - equalQuery = "%s.%s = ?" + equalQuery = "%s.%s = ?" + isNullQuery = "%s.%s IS NULL" ) type gormValueFilterImpl struct { @@ -27,7 +28,13 @@ func (g *gormValueFilterImpl) GetDBQueryExpression(tableName string) (models.DBQ Query: fmt.Sprintf(equalQuery, tableName, g.field), Args: g.value, }, nil + case common.IsNull: + return models.DBQueryExpr{ + Query: fmt.Sprintf(isNullQuery, tableName, g.field), + Args: g.value, + }, nil } + return models.DBQueryExpr{}, errors.GetUnsupportedFilterExpressionErr(g.comparisonOperator) } @@ -39,3 +46,10 @@ func NewGormValueFilter(comparisonOperator common.ComparisonOperator, field stri value: value, } } + +func NewGormNullFilter(field string) models.ModelValueFilter { + return &gormValueFilterImpl{ + comparisonOperator: common.IsNull, + field: field, + } +} diff --git a/pkg/repositories/gormimpl/list.go b/pkg/repositories/gormimpl/list.go index 7e461df5..6cdcedf7 100644 --- a/pkg/repositories/gormimpl/list.go +++ b/pkg/repositories/gormimpl/list.go @@ -56,7 +56,12 @@ func applyListModelsInput(tx *gorm.DB, sourceEntity common.Entity, in models.Lis if err != nil { return nil, err } - tx = tx.Where(dbQueryExpr.Query, dbQueryExpr.Args) + + if dbQueryExpr.Args == nil { + tx = tx.Where(dbQueryExpr.Query) + } else { + tx = tx.Where(dbQueryExpr.Query, dbQueryExpr.Args) + } } } diff --git a/pkg/repositories/gormimpl/tag.go b/pkg/repositories/gormimpl/tag.go index b77fc3ca..34be53b6 100644 --- a/pkg/repositories/gormimpl/tag.go +++ b/pkg/repositories/gormimpl/tag.go @@ -65,7 +65,7 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { Entity: common.Tag, ValueFilters: []models.ModelValueFilter{ NewGormValueFilter(common.Equal, "tag_name", tag.TagName), - NewGormValueFilter(common.Equal, "deleted_at", gorm.Expr("NULL")), // AC: this may not work, may have to specially handle nil + NewGormNullFilter("deleted_at"), }, JoinCondition: NewGormJoinCondition(common.Artifact, common.Tag), }) @@ -84,10 +84,10 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { } var artifacts []models.Artifact - if listArtifactsScope.Find(&artifacts).Error != nil { - logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, listArtifactsScope.Error) + if err := listArtifactsScope.Find(&artifacts).Error; err != nil { + logger.Errorf(ctx, "Unable to find previously tagged artifacts, rolling back, tag: [%v], err [%v]", tag, err) tx.Rollback() - return h.errorTransformer.ToDataCatalogError(listArtifactsScope.Error) + return h.errorTransformer.ToDataCatalogError(err) } // 3. Remove the tags from the currently tagged artifacts @@ -122,7 +122,7 @@ func (h *tagRepo) Create(ctx context.Context, tag models.Tag) error { return h.errorTransformer.ToDataCatalogError(tx.Error) } - // 5. Tag the new artifact + // 5. Tag the new artifact, if it didn't previously exist if undeleteScope.RowsAffected == 0 { if err := tx.Create(&tag).Error; err != nil { logger.Errorf(ctx, "Unable to create tag, rolling back, tag: [%v], err [%v]", tag, err) diff --git a/pkg/repositories/gormimpl/tag_test.go b/pkg/repositories/gormimpl/tag_test.go index e0cd9169..442b4b8d 100644 --- a/pkg/repositories/gormimpl/tag_test.go +++ b/pkg/repositories/gormimpl/tag_test.go @@ -59,7 +59,7 @@ func TestCreateTagNew(t *testing.T) { `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (123)))`).WithReply(getDBPartitionResponse(newArtifact)) GlobalMock.NewMock().WithQuery( - `SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at = NULL)) LIMIT 100 OFFSET 0`).WithReply([]map[string]interface{}{}) + `SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at IS NULL)) LIMIT 100 OFFSET 0`).WithReply([]map[string]interface{}{}) GlobalMock.NewMock().WithQuery( `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback( @@ -95,7 +95,7 @@ func TestStealOldTag(t *testing.T) { `SELECT * FROM "partitions" WHERE "partitions"."deleted_at" IS NULL AND (("artifact_id" IN (111)))`).WithReply(getDBPartitionResponse(newArtifact)) GlobalMock.NewMock().WithQuery( - `SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at = NULL)) LIMIT 100 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact)) + `SELECT "artifacts".* FROM "artifacts" JOIN partitions partitions0 ON artifacts.artifact_id = partitions0.artifact_id JOIN tags tags1 ON artifacts.artifact_id = tags1.artifact_id WHERE "artifacts"."deleted_at" IS NULL AND ((partitions0.key = region) AND (partitions0.value = SEA) AND (tags1.tag_name = test-tagname) AND (tags1.deleted_at IS NULL)) LIMIT 100 OFFSET 0`).WithReply(getDBArtifactResponse(oldArtifact)) GlobalMock.NewMock().WithQuery( `INSERT INTO "tags" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","tag_name","artifact_id","dataset_uuid") VALUES (?,?,?,?,?,?,?,?,?,?)`).WithCallback(