diff --git a/server/src/main/java/com/cloud/api/query/QueryManagerImpl.java b/server/src/main/java/com/cloud/api/query/QueryManagerImpl.java index 1808230d8855..be695ac96c64 100644 --- a/server/src/main/java/com/cloud/api/query/QueryManagerImpl.java +++ b/server/src/main/java/com/cloud/api/query/QueryManagerImpl.java @@ -29,7 +29,6 @@ import java.util.Map; import java.util.Set; import java.util.UUID; -import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -3813,11 +3812,62 @@ else if (!template.isPublicTemplate() && caller.getType() != Account.Type.ADMIN) } } + applyPublicTemplateSharingRestrictions(sc, caller); + return templateChecks(isIso, hypers, tags, name, keyword, hyperType, onlyReady, bootable, zoneId, showDomr, caller, showRemovedTmpl, parentTemplateId, showUnique, searchFilter, sc); } + /** + * If the caller is not a root admin, restricts the search to return only public templates from the domain which + * the caller belongs to and domains with the setting 'share.public.templates.with.other.domains' enabled. + */ + protected void applyPublicTemplateSharingRestrictions(SearchCriteria sc, Account caller) { + if (caller.getType() == Account.Type.ADMIN) { + s_logger.debug(String.format("Account [%s] is a root admin. Therefore, it has access to all public templates.", caller)); + return; + } + + List publicTemplates = _templateJoinDao.listPublicTemplates(); + + Set unsharableDomainIds = new HashSet<>(); + for (TemplateJoinVO template : publicTemplates) { + addDomainIdToSetIfDomainDoesNotShareTemplates(template.getDomainId(), caller, unsharableDomainIds); + } + + if (!unsharableDomainIds.isEmpty()) { + s_logger.info(String.format("The public templates belonging to the domains [%s] will not be listed to account [%s] as they have the configuration [%s] marked as 'false'.", unsharableDomainIds, caller, QueryService.SharePublicTemplatesWithOtherDomains.key())); + sc.addAnd("domainId", SearchCriteria.Op.NOTIN, unsharableDomainIds.toArray()); + } + } + + /** + * Adds the provided domain ID to the set if the domain does not share templates with the account. That is, if: + * (1) the template does not belong to the domain of the account AND + * (2) the domain of the template has the setting 'share.public.templates.with.other.domains' disabled. + */ + protected void addDomainIdToSetIfDomainDoesNotShareTemplates(long domainId, Account account, Set unsharableDomainIds) { + if (domainId == account.getDomainId()) { + s_logger.trace(String.format("Domain [%s] will not be added to the set of domains with unshared templates since the account [%s] belongs to it.", domainId, account)); + return; + } + + if (unsharableDomainIds.contains(domainId)) { + s_logger.trace(String.format("Domain [%s] is already on the set of domains with unshared templates.", domainId)); + return; + } + + if (!checkIfDomainSharesTemplates(domainId)) { + s_logger.debug(String.format("Domain [%s] will be added to the set of domains with unshared templates as configuration [%s] is false.", domainId, QueryService.SharePublicTemplatesWithOtherDomains.key())); + unsharableDomainIds.add(domainId); + } + } + + protected boolean checkIfDomainSharesTemplates(Long domainId) { + return QueryService.SharePublicTemplatesWithOtherDomains.valueIn(domainId); + } + private Pair, Integer> templateChecks(boolean isIso, List hypers, Map tags, String name, String keyword, HypervisorType hyperType, boolean onlyReady, Boolean bootable, Long zoneId, boolean showDomr, Account caller, boolean showRemovedTmpl, Long parentTemplateId, Boolean showUnique, @@ -3947,27 +3997,9 @@ private Pair, Integer> findTemplatesByIdOrTempZonePair(Pair templates = _templateJoinDao.searchByTemplateZonePair(showRemoved, templateZonePairs); } - if(caller.getType() != Account.Type.ADMIN) { - templates = applyPublicTemplateRestriction(templates, caller); - count = templates.size(); - } - return new Pair, Integer>(templates, count); } - private List applyPublicTemplateRestriction(List templates, Account caller){ - List unsharableDomainIds = templates.stream() - .map(TemplateJoinVO::getDomainId) - .distinct() - .filter(domainId -> domainId != caller.getDomainId()) - .filter(Predicate.not(QueryService.SharePublicTemplatesWithOtherDomains::valueIn)) - .collect(Collectors.toList()); - - return templates.stream() - .filter(Predicate.not(t -> unsharableDomainIds.contains(t.getDomainId()))) - .collect(Collectors.toList()); - } - @Override public ListResponse listIsos(ListIsosCmd cmd) { Pair, Integer> result = searchForIsosInternal(cmd); diff --git a/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDao.java b/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDao.java index 58cb886594fd..1b7edd325922 100644 --- a/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDao.java +++ b/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDao.java @@ -48,6 +48,8 @@ public interface TemplateJoinDao extends GenericDao { List listActiveTemplates(long storeId); + List listPublicTemplates(); + Pair, Integer> searchIncludingRemovedAndCount(final SearchCriteria sc, final Filter filter); List findByDistinctIds(Long... ids); diff --git a/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDaoImpl.java b/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDaoImpl.java index f20a9aa2e133..4fe0f200741e 100644 --- a/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDaoImpl.java +++ b/server/src/main/java/com/cloud/api/query/dao/TemplateJoinDaoImpl.java @@ -104,6 +104,8 @@ public class TemplateJoinDaoImpl extends GenericDaoBaseWithTagInformation activeTmpltSearch; + private final SearchBuilder publicTmpltSearch; + protected TemplateJoinDaoImpl() { tmpltIdPairSearch = createSearchBuilder(); @@ -137,6 +139,10 @@ protected TemplateJoinDaoImpl() { activeTmpltSearch.cp(); activeTmpltSearch.done(); + publicTmpltSearch = createSearchBuilder(); + publicTmpltSearch.and("public", publicTmpltSearch.entity().isPublicTemplate(), SearchCriteria.Op.EQ); + publicTmpltSearch.done(); + // select distinct pair (template_id, zone_id) _count = "select count(distinct temp_zone_pair) from template_view WHERE "; } @@ -572,6 +578,13 @@ public List listActiveTemplates(long storeId) { return searchIncludingRemoved(sc, null, null, false); } + @Override + public List listPublicTemplates() { + SearchCriteria sc = publicTmpltSearch.create(); + sc.setParameters("public", Boolean.TRUE); + return listBy(sc); + } + @Override public Pair, Integer> searchIncludingRemovedAndCount(final SearchCriteria sc, final Filter filter) { List objects = searchIncludingRemoved(sc, filter, null, false); diff --git a/server/src/test/java/com/cloud/api/query/QueryManagerImplTest.java b/server/src/test/java/com/cloud/api/query/QueryManagerImplTest.java index 9323e7f0a3e7..29361bfeeaea 100644 --- a/server/src/test/java/com/cloud/api/query/QueryManagerImplTest.java +++ b/server/src/test/java/com/cloud/api/query/QueryManagerImplTest.java @@ -17,7 +17,9 @@ package com.cloud.api.query; +import com.cloud.api.query.dao.TemplateJoinDao; import com.cloud.api.query.vo.EventJoinVO; +import com.cloud.api.query.vo.TemplateJoinVO; import com.cloud.event.dao.EventJoinDao; import com.cloud.exception.InvalidParameterValueException; import com.cloud.exception.PermissionDeniedException; @@ -48,10 +50,13 @@ import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.mockito.Spy; import org.mockito.junit.MockitoJUnitRunner; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.UUID; import static org.mockito.Mockito.when; @@ -61,13 +66,28 @@ public class QueryManagerImplTest { public static final long USER_ID = 1; public static final long ACCOUNT_ID = 1; + @Spy + @InjectMocks + private QueryManagerImpl queryManagerImplSpy = new QueryManagerImpl(); + @Mock EntityManager entityManager; + @Mock AccountManager accountManager; + @Mock EventJoinDao eventJoinDao; + @Mock + Account accountMock; + + @Mock + TemplateJoinDao templateJoinDaoMock; + + @Mock + SearchCriteria searchCriteriaMock; + private AccountVO account; private UserVO user; @@ -176,4 +196,67 @@ public void searchForEventsFailPermissionDenied() { Mockito.doThrow(new PermissionDeniedException("Denied")).when(accountManager).checkAccess(account, SecurityChecker.AccessType.ListEntry, false, network); queryManager.searchForEvents(cmd); } + + @Test + public void applyPublicTemplateRestrictionsTestDoesNotApplyRestrictionsWhenCallerIsRootAdmin() { + Mockito.when(accountMock.getType()).thenReturn(Account.Type.ADMIN); + + queryManagerImplSpy.applyPublicTemplateSharingRestrictions(searchCriteriaMock, accountMock); + + Mockito.verify(searchCriteriaMock, Mockito.never()).addAnd(Mockito.anyString(), Mockito.any(), Mockito.any()); + } + + @Test + public void applyPublicTemplateRestrictionsTestAppliesRestrictionsWhenCallerIsNotRootAdmin() { + long callerDomainId = 1L; + long sharableDomainId = 2L; + long unsharableDomainId = 3L; + + Mockito.when(accountMock.getType()).thenReturn(Account.Type.NORMAL); + + Mockito.when(accountMock.getDomainId()).thenReturn(callerDomainId); + TemplateJoinVO templateMock1 = Mockito.mock(TemplateJoinVO.class); + Mockito.when(templateMock1.getDomainId()).thenReturn(callerDomainId); + Mockito.lenient().doReturn(false).when(queryManagerImplSpy).checkIfDomainSharesTemplates(callerDomainId); + + TemplateJoinVO templateMock2 = Mockito.mock(TemplateJoinVO.class); + Mockito.when(templateMock2.getDomainId()).thenReturn(sharableDomainId); + Mockito.doReturn(true).when(queryManagerImplSpy).checkIfDomainSharesTemplates(sharableDomainId); + + TemplateJoinVO templateMock3 = Mockito.mock(TemplateJoinVO.class); + Mockito.when(templateMock3.getDomainId()).thenReturn(unsharableDomainId); + Mockito.doReturn(false).when(queryManagerImplSpy).checkIfDomainSharesTemplates(unsharableDomainId); + + List publicTemplates = List.of(templateMock1, templateMock2, templateMock3); + Mockito.when(templateJoinDaoMock.listPublicTemplates()).thenReturn(publicTemplates); + + queryManagerImplSpy.applyPublicTemplateSharingRestrictions(searchCriteriaMock, accountMock); + + Mockito.verify(searchCriteriaMock).addAnd("domainId", SearchCriteria.Op.NOTIN, unsharableDomainId); + } + + @Test + public void addDomainIdToSetIfDomainDoesNotShareTemplatesTestDoesNotAddWhenCallerBelongsToDomain() { + long domainId = 1L; + Set set = new HashSet<>(); + + Mockito.when(accountMock.getDomainId()).thenReturn(domainId); + + queryManagerImplSpy.addDomainIdToSetIfDomainDoesNotShareTemplates(domainId, accountMock, set); + + Assert.assertEquals(0, set.size()); + } + + @Test + public void addDomainIdToSetIfDomainDoesNotShareTemplatesTestAddsWhenDomainDoesNotShareTemplates() { + long domainId = 1L; + Set set = new HashSet<>(); + + Mockito.when(accountMock.getDomainId()).thenReturn(2L); + Mockito.doReturn(false).when(queryManagerImplSpy).checkIfDomainSharesTemplates(domainId); + + queryManagerImplSpy.addDomainIdToSetIfDomainDoesNotShareTemplates(domainId, accountMock, set); + + Assert.assertTrue(set.contains(domainId)); + } }