Skip to content

Commit

Permalink
A batch query utility to replace TransactionManager's loadAllOf met…
Browse files Browse the repository at this point in the history
…hods (#2589)

* Replace loadAllOf with batch query

* Addressing CR
  • Loading branch information
weiminyu authored Oct 14, 2024
1 parent 020ed33 commit 634202c
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright 2024 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package google.registry.persistence.transaction;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static google.registry.persistence.PersistenceModule.TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import com.google.common.collect.UnmodifiableIterator;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.metamodel.EntityType;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Helper for querying large data sets in batches.
 *
 * <p>All {@code loadAllOf} overloads return a {@link Stream} whose elements are batches (lists) of
 * entities ordered by their single-column ID. Each batch is read in its own transaction.
 */
public final class BatchedQueries {

  /** Batch size used when the caller does not specify one. */
  private static final int DEFAULT_BATCH_SIZE = 500;

  private BatchedQueries() {}

  /**
   * Loads all entities of type {@code T} in batches of {@code DEFAULT_BATCH_SIZE}, using the
   * transaction manager returned by {@code tm()}.
   *
   * @see #loadAllOf(JpaTransactionManager, Class, int)
   */
  public static <T> Stream<ImmutableList<T>> loadAllOf(Class<T> entityType) {
    return loadAllOf(entityType, DEFAULT_BATCH_SIZE);
  }

  /**
   * Loads all entities of type {@code T} in batches of at most {@code batchSize}, using the
   * transaction manager returned by {@code tm()}.
   *
   * @see #loadAllOf(JpaTransactionManager, Class, int)
   */
  public static <T> Stream<ImmutableList<T>> loadAllOf(Class<T> entityType, int batchSize) {
    return loadAllOf(tm(), entityType, batchSize);
  }

  /**
   * Loads all entities of type {@code T} in batches.
   *
   * <p>This method must not be nested in any transaction; same for the traversal of the returned
   * {@link Stream}. Each batch is loaded in a separate transaction at the {@code
   * TRANSACTION_REPEATABLE_READ} isolation level, and loads the snapshot of the batch at the
   * batch's start time. New insertions or updates since then are not reflected in the result.
   *
   * <p>The first batch is read eagerly, i.e., before this method returns.
   *
   * @param jpaTm the transaction manager used to run each batch query
   * @param entityType the entity class to load; must have a single-column primary key
   * @param batchSize the maximum number of entities per batch; must be positive
   * @return a stream of non-empty batches, ordered by entity ID
   * @throws IllegalStateException if called inside a transaction
   * @throws IllegalArgumentException if {@code batchSize} is not positive
   * @throws UnsupportedOperationException if {@code entityType} has a multi-column primary key
   */
  public static <T> Stream<ImmutableList<T>> loadAllOf(
      JpaTransactionManager jpaTm, Class<T> entityType, int batchSize) {
    checkState(!jpaTm.inTransaction(), "loadAllOf cannot be nested in a transaction");
    checkArgument(batchSize > 0, "batchSize must be positive");
    EntityType<T> jpaEntityType = jpaTm.getMetaModel().entity(entityType);
    if (!jpaEntityType.hasSingleIdAttribute()) {
      // We should support multi-column primary keys on a case-by-case basis.
      throw new UnsupportedOperationException(
          "Types with multi-column primary key not supported yet");
    }
    return Streams.stream(
        new BatchedIterator<>(new SingleColIdBatchQuery<>(jpaTm, jpaEntityType), batchSize));
  }

  /** A query that can be executed repeatedly to page through a data set. */
  public interface BatchQuery<T> {
    /**
     * Reads the next batch.
     *
     * @param lastRead the last entity of the previous batch, or empty when reading the first batch
     * @param batchSize the maximum number of entities to return
     * @return the next batch; empty when there is nothing left to read
     */
    ImmutableList<T> readBatch(Optional<T> lastRead, int batchSize);
  }

  /** A {@link BatchQuery} that pages through entities ordered by their single-column ID. */
  private static class SingleColIdBatchQuery<T> implements BatchQuery<T> {

    private final JpaTransactionManager jpaTm;
    private final Class<T> entityType;
    /** JPQL for the first batch: no lower bound on the ID. */
    private final String initialJpqlQuery;
    /** JPQL for every later batch: takes the last-read ID as the {@code :id} parameter. */
    private final String subsequentJpqlTemplate;

    private SingleColIdBatchQuery(JpaTransactionManager jpaTm, EntityType<T> jpaEntityType) {
      checkArgument(
          jpaEntityType.hasSingleIdAttribute(),
          "%s must have a single ID attribute",
          jpaEntityType.getJavaType().getSimpleName());
      this.jpaTm = jpaTm;
      this.entityType = jpaEntityType.getJavaType();
      var idAttr = jpaEntityType.getId(jpaEntityType.getIdType().getJavaType());
      this.initialJpqlQuery =
          String.format("FROM %s ORDER BY %s", jpaEntityType.getName(), idAttr.getName());
      this.subsequentJpqlTemplate =
          String.format(
              "FROM %1$s WHERE %2$s > :id ORDER BY %2$s",
              jpaEntityType.getName(), idAttr.getName());
    }

    @Override
    public ImmutableList<T> readBatch(Optional<T> lastRead, int batchSize) {
      checkState(!jpaTm.inTransaction(), "Stream cannot be accessed in a transaction");
      return jpaTm.transact(
          TRANSACTION_REPEATABLE_READ,
          () -> {
            var entityManager = jpaTm.getEntityManager();
            Optional<Object> lastReadId =
                lastRead.map(
                    entityManager.getEntityManagerFactory().getPersistenceUnitUtil()
                        ::getIdentifier);
            TypedQuery<T> query =
                lastRead.isEmpty()
                    ? entityManager.createQuery(initialJpqlQuery, entityType)
                    : entityManager
                        .createQuery(subsequentJpqlTemplate, entityType)
                        .setParameter("id", lastReadId.get());

            var results = ImmutableList.copyOf(query.setMaxResults(batchSize).getResultList());
            // Hand detached instances back to the caller.
            results.forEach(entityManager::detach);
            return results;
          });
    }
  }

  /**
   * Iterator over the batches produced by a {@link BatchQuery}.
   *
   * <p>The iterator always holds the batch that the next call to {@link #next()} will return: the
   * first batch is read in the constructor, and each call to {@code next()} pre-fetches the
   * following one.
   */
  private static class BatchedIterator<T> extends UnmodifiableIterator<ImmutableList<T>> {

    private final BatchQuery<T> batchQuery;

    private final int batchSize;

    /** The batch to be returned by the next call to {@link #next()}; empty when exhausted. */
    private ImmutableList<T> cachedBatch = null;

    private BatchedIterator(BatchQuery<T> batchQuery, int batchSize) {
      this.batchQuery = batchQuery;
      this.batchSize = batchSize;
      this.cachedBatch = readNextBatch();
    }

    @Override
    public boolean hasNext() {
      return !cachedBatch.isEmpty();
    }

    /**
     * Returns the cached batch and pre-fetches the following one.
     *
     * @throws NoSuchElementException if there are no more batches
     */
    @Override
    public ImmutableList<T> next() {
      if (!hasNext()) {
        // Honor the Iterator contract instead of silently returning an empty batch.
        throw new NoSuchElementException();
      }
      var toReturn = cachedBatch;
      // A short batch means the data set is exhausted, so the extra query can be skipped.
      cachedBatch = cachedBatch.size() < batchSize ? ImmutableList.of() : readNextBatch();
      return toReturn;
    }

    private ImmutableList<T> readNextBatch() {
      // A null cachedBatch means no batch has been read yet (we are in the constructor).
      Optional<T> lastRead =
          cachedBatch == null
              ? Optional.empty()
              : Optional.ofNullable(Iterables.getLast(cachedBatch, null));
      return batchQuery.readBatch(lastRead, batchSize);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jakarta.persistence.Query;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.metamodel.Metamodel;

/** Sub-interface of {@link TransactionManager} which defines JPA related methods. */
public interface JpaTransactionManager extends TransactionManager {
Expand All @@ -31,6 +32,9 @@ public interface JpaTransactionManager extends TransactionManager {
*/
EntityManager getStandaloneEntityManager();

/** Returns the JPA {@link Metamodel}. */
Metamodel getMetaModel();

/**
* Returns the {@link EntityManager} for the current request.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.metamodel.EntityType;
import jakarta.persistence.metamodel.Metamodel;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
Expand Down Expand Up @@ -116,6 +117,11 @@ public EntityManager getStandaloneEntityManager() {
return emf.createEntityManager();
}

/** Returns the {@link Metamodel} obtained from {@code emf}. */
@Override
public Metamodel getMetaModel() {
  return emf.getMetamodel();
}

@Override
public EntityManager getEntityManager() {
assertInTransaction();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
package google.registry.tools.server;

import static com.google.common.collect.ImmutableSortedSet.toImmutableSortedSet;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.persistence.transaction.BatchedQueries.loadAllOf;
import static google.registry.request.Action.Method.GET;
import static google.registry.request.Action.Method.POST;
import static java.util.Comparator.comparing;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import google.registry.model.EppResourceUtils;
import google.registry.model.host.Host;
Expand Down Expand Up @@ -51,7 +52,8 @@ public ImmutableSet<String> getPrimaryKeyFields() {
@Override
public ImmutableSet<Host> loadObjects() {
final DateTime now = clock.nowUtc();
return tm().transact(() -> tm().loadAllOf(Host.class)).stream()
return loadAllOf(Host.class)
.flatMap(ImmutableList::stream)
.filter(host -> EppResourceUtils.isActive(host, now))
.collect(toImmutableSortedSet(comparing(Host::getHostName)));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2024 The Nomulus Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package google.registry.persistence.transaction;

import static com.google.common.truth.Truth.assertThat;
import static google.registry.persistence.transaction.BatchedQueries.loadAllOf;
import static google.registry.persistence.transaction.TransactionManagerFactory.tm;
import static google.registry.testing.DatabaseHelper.persistResource;

import com.google.common.collect.ImmutableList;
import google.registry.model.ImmutableObject;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/** Unit tests for {@link BatchedQueries}. */
class BatchedQueriesTest {

  @RegisterExtension
  final JpaTestExtensions.JpaUnitTestExtension jpa =
      new JpaTestExtensions.Builder()
          .withEntityClass(LongIdEntity.class, StringIdEntity.class)
          .buildUnitTestExtension();

  @Test
  void loadAllOf_noData() {
    assertThat(loadAllOf(StringIdEntity.class)).isEmpty();
  }

  @Test
  void loadAllOf_oneEntry() {
    StringIdEntity onlyEntity = persistResource(new StringIdEntity("C1"));
    ImmutableList<StringIdEntity> expectedBatch = ImmutableList.of(onlyEntity);
    assertThat(loadAllOf(StringIdEntity.class)).containsExactly(expectedBatch);
  }

  @Test
  void loadAllOf_multipleEntries_fullBatches() {
    // Persist in descending ID order while expecting ascending order in the result, so that this
    // test depends on the `order by` clause being present in the query.
    StringIdEntity fourth = persistResource(new StringIdEntity("C4"));
    StringIdEntity third = persistResource(new StringIdEntity("C3"));
    StringIdEntity second = persistResource(new StringIdEntity("C2"));
    StringIdEntity first = persistResource(new StringIdEntity("C1"));
    ImmutableList<StringIdEntity> firstBatch = ImmutableList.of(first, second);
    ImmutableList<StringIdEntity> secondBatch = ImmutableList.of(third, fourth);
    assertThat(loadAllOf(StringIdEntity.class, 2))
        .containsExactly(firstBatch, secondBatch)
        .inOrder();
  }

  @Test
  void loadAllOf_multipleEntries_withPartialBatch() {
    StringIdEntity first = persistResource(new StringIdEntity("C1"));
    StringIdEntity second = persistResource(new StringIdEntity("C2"));
    StringIdEntity third = persistResource(new StringIdEntity("C3"));
    StringIdEntity fourth = persistResource(new StringIdEntity("C4"));
    ImmutableList<StringIdEntity> fullBatch = ImmutableList.of(first, second, third);
    ImmutableList<StringIdEntity> partialBatch = ImmutableList.of(fourth);
    assertThat(loadAllOf(StringIdEntity.class, 3))
        .containsExactly(fullBatch, partialBatch)
        .inOrder();
  }

  @Test
  void loadAllOf_multipleEntries_withLongNumberAsId() {
    // IDs 2 and 10 are expected back in numeric order (2 before 10).
    LongIdEntity idTwo = new LongIdEntity(2L);
    LongIdEntity idTen = new LongIdEntity(10L);
    tm().transact(() -> tm().put(idTwo));
    tm().transact(() -> tm().put(idTen));

    ImmutableList<LongIdEntity> firstBatch = ImmutableList.of(idTwo);
    ImmutableList<LongIdEntity> secondBatch = ImmutableList.of(idTen);
    assertThat(loadAllOf(LongIdEntity.class, 1))
        .containsExactly(firstBatch, secondBatch)
        .inOrder();
  }

  /** Test entity keyed by a {@code String} ID. */
  @Entity(name = "StringIdEntity")
  static class StringIdEntity extends ImmutableObject {
    @Id String id;

    StringIdEntity() {}

    private StringIdEntity(String id) {
      this.id = id;
    }
  }

  /** Test entity keyed by a {@code long} ID. */
  @Entity(name = "LongIdEntity")
  private static class LongIdEntity extends ImmutableObject {
    @Id long entityId;

    LongIdEntity() {}

    private LongIdEntity(long id) {
      this.entityId = id;
    }
  }
}

0 comments on commit 634202c

Please sign in to comment.