diff --git a/base/database/testing.go b/base/database/testing.go index 982b73865..ce3e6df6e 100644 --- a/base/database/testing.go +++ b/base/database/testing.go @@ -9,7 +9,9 @@ import ( "testing" "time" + "github.com/lib/pq" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func DebugWithCachesCheck(part string, fun func()) { @@ -173,16 +175,6 @@ func CheckAdvisoriesAccountDataNotified(t *testing.T, rhAccountID int, advisoryI } } -func CheckSystemUpdatesCount(t *testing.T, accountID int, systemID int64) []int { - var cnt []int - assert.NoError(t, Db.Table("system_package spkg"). - Select("json_array_length(spkg.update_data::json) as len"). - Where("spkg.update_data is not null"). - Where("spkg.system_id = ? AND spkg.rh_account_id = ? ", systemID, accountID). - Pluck("len", &cnt).Error) - return cnt -} - func CreateReportedAdvisories(reportedAdvisories []string, status []int) map[string]int { reportedAdvisoriesMap := make(map[string]int, len(reportedAdvisories)) for i, adv := range reportedAdvisories { @@ -256,13 +248,27 @@ func CheckEVRAsInDBSynced(t *testing.T, nExpected int, synced bool, evras ...str } func CheckSystemPackages(t *testing.T, accountID int, systemID int64, nExpected int, packageIDs ...int64) { - var systemPackages []models.SystemPackage - query := Db.Where("rh_account_id = ? AND system_id = ?", accountID, systemID) + // check system_package_data + var foundIDs []int64 + sysQuery := Db.Table(`(SELECT jsonb_object_keys(update_data)::bigint as package_id + FROM system_package_data + WHERE rh_account_id = ? AND system_id = ?) as t`, accountID, systemID) + if len(packageIDs) > 0 { + sysQuery = sysQuery.Where("package_id in (?)", packageIDs) + } + assert.Nil(t, sysQuery.Pluck("package_id", &foundIDs).Error) + assert.Equal(t, nExpected, len(foundIDs)) + + // check package_system_data + var foundNameIDs []int64 + pkgQuery := Db.Table("package_system_data psd"). + Where("psd.rh_account_id = ? AND psd.update_data->? IS NOT NULL", accountID, strconv.FormatInt(systemID, 10)) if len(packageIDs) > 0 { - query = query.Where("package_id IN (?)", packageIDs) + pkgQuery = pkgQuery.Joins("JOIN package p ON p.name_id = psd.package_name_id"). + Where("p.id in (?)", packageIDs) } - assert.Nil(t, query.Find(&systemPackages).Error) - assert.Equal(t, nExpected, len(systemPackages)) + assert.Nil(t, pkgQuery.Pluck("package_name_id", &foundNameIDs).Error) + assert.Equal(t, nExpected, len(foundNameIDs)) } func CheckSystemRepos(t *testing.T, rhAccountID int, systemID int64, repoIDs []int64) { @@ -296,15 +302,34 @@ func DeleteAdvisoryAccountData(t *testing.T, rhAccountID int, advisoryIDs []int6 } func DeleteSystemPackages(t *testing.T, accountID int, systemID int64, pkgIDs ...int64) { - query := Db.Model(&models.SystemPackage{}).Where("rh_account_id = ? AND system_id = ?", accountID, systemID) + // delete system_package_data if len(pkgIDs) > 0 { - query = query.Where("package_id in (?)", pkgIDs) + keys := make([]string, len(pkgIDs)) + for i, pid := range pkgIDs { + keys[i] = strconv.FormatInt(pid, 10) + } + assert.Nil(t, Db.Model(&models.SystemPackageData{}). + Where("rh_account_id = ? and system_id = ?", accountID, systemID). + Update("update_data", gorm.Expr("update_data - ?", pq.StringArray(keys))).Error) + } else { + assert.Nil(t, Db.Where("rh_account_id = ? AND system_id = ?", accountID, systemID). + Delete(&models.SystemPackageData{}).Error) } - assert.Nil(t, query.Delete(&models.SystemPackage{}).Error) - var count int64 // ensure deleted - assert.Nil(t, query.Count(&count).Error) - assert.Equal(t, int64(0), count) + // delete package_system_data + systemIDKey := strconv.FormatInt(systemID, 10) + if len(pkgIDs) > 0 { + assert.Nil(t, Db.Exec(`UPDATE package_system_data psd SET update_data = update_data - ? + FROM package p + WHERE psd.rh_account_id = ? + AND psd.update_data->? IS NOT NULL + AND p.name_id = psd.package_name_id + AND p.id in (?)`, systemIDKey, accountID, systemIDKey, pkgIDs).Error) + } else { + query := Db.Table("package_system_data psd"). + Where("psd.rh_account_id = ? AND psd.update_data->? IS NOT NULL", accountID, systemIDKey) + assert.Nil(t, query.Update("update_data", gorm.Expr("update_data - ?", strconv.FormatInt(systemID, 10))).Error) + } } func DeleteSystemRepos(t *testing.T, rhAccountID int, systemID int64, repoIDs []int64) { diff --git a/evaluator/evaluate_test.go b/evaluator/evaluate_test.go index 711c160eb..9a341974c 100644 --- a/evaluator/evaluate_test.go +++ b/evaluator/evaluate_test.go @@ -135,14 +135,6 @@ func TestEvaluateYum(t *testing.T) { assert.Equal(t, 1, len(mockWriter.Messages)) } -func TestEvaluatePruneUpdates(t *testing.T) { - TestEvaluate(t) - count := database.CheckSystemUpdatesCount(t, rhAccountID, systemID) - for _, c := range count { - assert.LessOrEqual(t, c, 1, "All packages should only have single update stored") - } -} - func TestRun(t *testing.T) { configure() var nReaders int32