Fix preset sharding options and add tests #1

Merged · 3 commits · Jul 26, 2024
2 changes: 2 additions & 0 deletions build.gradle
@@ -59,6 +59,8 @@ dependencies {

test {
useJUnit()

maxHeapSize = "2g"
}

jar {
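(The larger test heap is presumably needed by the new sharding round-trip test below, which converts a synthetic 10240×10240 image.)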
101 changes: 85 additions & 16 deletions src/main/java/com/glencoesoftware/zarr/Convert.java
@@ -37,13 +37,15 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.Callable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ch.qos.logback.classic.Level;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import picocli.CommandLine.Parameters;
@@ -61,6 +63,7 @@ public class Convert implements Callable<Integer> {

private String inputLocation;
private String outputLocation;
private String logLevel = "INFO";
private boolean writeV2;

private ShardConfiguration shardConfig;
@@ -90,6 +93,26 @@ public void setOutput(String output) {
outputLocation = output;
}

/**
* Set the slf4j logging level. Defaults to "INFO".
*
* @param level logging level
*/
@Option(
names = {"--log-level", "--debug"},
arity = "0..1",
description = "Change logging level; valid values are " +
"OFF, ERROR, WARN, INFO, DEBUG, TRACE and ALL. " +
"(default: ${DEFAULT-VALUE})",
defaultValue = "INFO",
fallbackValue = "DEBUG"
)
public void setLogLevel(String level) {
if (level != null) {
logLevel = level;
}
}

@Option(
names = "--write-v2",
description = "Read v3, write v2",
@@ -134,6 +157,10 @@ public void setCompression(String[] compression) {

@Override
public Integer call() throws Exception {
ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger)
LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME);
root.setLevel(Level.toLevel(logLevel));

if (writeV2) {
convertToV2();
}
@@ -151,6 +178,7 @@ public void convertToV3() throws Exception {
Path inputPath = Paths.get(inputLocation);

// get the root-level attributes
LOGGER.debug("opening v2 root group: {}", inputPath);
ZarrGroup reader = ZarrGroup.open(inputPath);
Map<String, Object> attributes = reader.getAttributes();

@@ -163,6 +191,7 @@
// but this doesn't seem to actually create the group
// separating the group creation and attribute writing into
// two calls seems to work correctly
LOGGER.debug("opening v3 root group: {}", outputLocation);
FilesystemStore outputStore = new FilesystemStore(outputLocation);
Group outputRootGroup = Group.create(outputStore.resolve());
outputRootGroup.setAttributes(attributes);
@@ -175,9 +204,11 @@

for (String seriesGroupKey : groupKeys) {
if (seriesGroupKey.indexOf("/") > 0) {
LOGGER.debug("skipping v2 group key: {}", seriesGroupKey);
continue;
}
Path seriesPath = inputPath.resolve(seriesGroupKey);
LOGGER.debug("opening v2 group: {}", seriesPath);
ZarrGroup seriesGroup = ZarrGroup.open(seriesPath);
LOGGER.info("opened {}", seriesPath);

@@ -190,13 +221,16 @@ public void convertToV3() throws Exception {
Set<String> columnKeys = seriesGroup.getGroupKeys();
// "pass through" if this is not HCS
if (columnKeys.size() == 0) {
LOGGER.debug("no column group keys (likely not HCS)");
columnKeys.add("");
}
for (String columnKey : columnKeys) {
if (columnKey.indexOf("/") > 0) {
LOGGER.debug("skipping v2 column group key: {}", columnKey);
continue;
}
Path columnPath = columnKey.isEmpty() ? seriesPath : seriesPath.resolve(columnKey);
LOGGER.debug("opening v2 group: {}", columnPath);
ZarrGroup column = ZarrGroup.open(columnPath);

if (!columnKey.isEmpty()) {
@@ -208,14 +242,15 @@ public void convertToV3() throws Exception {
Set<String> fieldKeys = column.getGroupKeys();
// "pass through" if this is not HCS
if (fieldKeys.size() == 0) {
LOGGER.debug("no field group keys");
fieldKeys.add("");
}

for (String fieldKey : fieldKeys) {
Path fieldPath = fieldKey.isEmpty() ? columnPath : columnPath.resolve(fieldKey);
LOGGER.debug("opening v2 field group: {}", fieldPath);
ZarrGroup field = ZarrGroup.open(fieldPath);


Map<String, Object> fieldAttributes = field.getAttributes();
if (!fieldKey.isEmpty()) {
Group outputFieldGroup = Group.create(outputStore.resolve(seriesGroupKey, columnKey, fieldKey));
@@ -239,12 +274,16 @@

for (int res=0; res<totalResolutions; res++) {
String resolutionPath = fieldPath + "/" + res;
LOGGER.debug("opening v2 array: {}", resolutionPath);

ZarrArray tile = field.openArray("/" + res);
LOGGER.info("opened array {}", resolutionPath);
int[] chunkSizes = tile.getChunks();
int[] originalChunkSizes = tile.getChunks();
int[] shape = tile.getShape();

int[] chunkSizes = new int[originalChunkSizes.length];
System.arraycopy(originalChunkSizes, 0, chunkSizes, 0, chunkSizes.length);

int[] gridPosition = new int[] {0, 0, 0, 0, 0};
int tileX = chunkSizes[chunkSizes.length - 2];
int tileY = chunkSizes[chunkSizes.length - 1];
@@ -257,22 +296,31 @@ public void convertToV3() throws Exception {
if (shardConfig != null) {
switch (shardConfig) {
case SINGLE:
codecBuilder = codecBuilder.withSharding(shape);
// single shard covering the whole image
// internal chunk sizes remain the same as in input data
chunkSizes = shape;
break;
case CHUNK:
codecBuilder = codecBuilder.withSharding(chunkSizes);
// exactly one shard per chunk
// no changes needed
break;
case SUPERCHUNK:
int[] shardSize = new int[chunkSizes.length];
System.arraycopy(chunkSizes, 0, shardSize, 0, shardSize.length);
shardSize[4] *= 2;
shardSize[3] *= 2;
codecBuilder = codecBuilder.withSharding(shardSize);
// each shard covers 2x2 chunks
chunkSizes[4] *= 2;
chunkSizes[3] *= 2;
break;
case CUSTOM:
// TODO
break;
}

if (chunkAndShardCompatible(originalChunkSizes, chunkSizes, shape)) {
codecBuilder = codecBuilder.withSharding(originalChunkSizes);
}
else {
LOGGER.warn("Skipping sharding due to incompatible sizes");
chunkSizes = originalChunkSizes;
}
}
if (codecs != null) {
for (String codecName : codecs) {
@@ -292,19 +340,21 @@ else if (codecName.equals("blosc")) {
}
final CodecBuilder builder = codecBuilder;

Array outputArray = Array.create(outputStore.resolve(seriesGroupKey, columnKey, fieldKey, String.valueOf(res)),
StoreHandle v3ArrayHandle = outputStore.resolve(seriesGroupKey, columnKey, fieldKey, String.valueOf(res));
LOGGER.debug("opening v3 array: {}", v3ArrayHandle);
Array outputArray = Array.create(v3ArrayHandle,
Array.metadataBuilder()
.withShape(Utils.toLongArray(shape))
.withDataType(getV3Type(type))
.withChunkShape(chunkSizes)
.withChunkShape(chunkSizes) // if sharding is used, this will be the shard size
.withFillValue(255)
.withCodecs(c -> builder)
.build()
);

for (int t=0; t<shape[0]; t+=chunkSizes[0]) {
for (int c=0; c<shape[1]; c+=chunkSizes[1]) {
for (int z=0; z<shape[2]; z+=chunkSizes[2]) {
for (int t=0; t<shape[0]; t+=originalChunkSizes[0]) {
for (int c=0; c<shape[1]; c+=originalChunkSizes[1]) {
for (int z=0; z<shape[2]; z+=originalChunkSizes[2]) {
// copy each chunk, keeping the original chunk sizes
for (int y=0; y<shape[4]; y+=tileY) {
for (int x=0; x<shape[3]; x+=tileX) {
Expand All @@ -313,8 +363,10 @@ else if (codecName.equals("blosc")) {
gridPosition[2] = z;
gridPosition[1] = c;
gridPosition[0] = t;
Object bytes = tile.read(chunkSizes, gridPosition);
outputArray.write(Utils.toLongArray(gridPosition), NetCDF_Util.createArrayWithGivenStorage(bytes, chunkSizes));
LOGGER.debug("copying chunk of size {} at position {}",
Arrays.toString(originalChunkSizes), Arrays.toString(gridPosition));
Object bytes = tile.read(originalChunkSizes, gridPosition);
outputArray.write(Utils.toLongArray(gridPosition), NetCDF_Util.createArrayWithGivenStorage(bytes, originalChunkSizes));
}
}
}
@@ -530,6 +582,23 @@ private DataType getV2Type(dev.zarr.zarrjava.v3.DataType v3) {
throw new IllegalArgumentException(v3.toString());
}

/**
* Check that the desired chunk, shard, and shape are compatible with each other.
* In each dimension, the chunk size must evenly divide into the shard size,
* which must evenly divide into the shape.
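*
* @param chunkSize chunk size in each dimension
* @param shardSize proposed shard size in each dimension
* @param shape full array shape in each dimension
* @return true if chunks evenly divide shards and shards evenly divide the shape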
*/
private boolean chunkAndShardCompatible(int[] chunkSize, int[] shardSize, int[] shape) {
for (int d=0; d<shape.length; d++) {
if (shape[d] % shardSize[d] != 0) {
return false;
}
if (shardSize[d] % chunkSize[d] != 0) {
return false;
}
}
return true;
}

public static void main(String[] args) {
CommandLine.call(new Convert(), args);
}
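The divisibility rule enforced by chunkAndShardCompatible is easiest to see with concrete sizes. Below is a minimal, self-contained sketch of that rule; the class name and sizes are illustrative only, chosen to match the 10240×10240 test image used in ConversionTest, and are not part of this diff.

```java
/**
 * Minimal sketch of the rule checked by chunkAndShardCompatible:
 * in every dimension, the chunk size must evenly divide the shard size,
 * and the shard size must evenly divide the array shape.
 */
public class ShardCompatibilityDemo {

    static boolean compatible(int[] chunk, int[] shard, int[] shape) {
        for (int d = 0; d < shape.length; d++) {
            if (shape[d] % shard[d] != 0 || shard[d] % chunk[d] != 0) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[] shape = {1, 1, 1, 10240, 10240}; // full 5-D array
        int[] chunk = {1, 1, 1, 1024, 1024};   // chunk size in the input data

        // CHUNK preset: shard == chunk, trivially compatible
        System.out.println(compatible(chunk, chunk, shape)); // true

        // SUPERCHUNK preset: each shard covers 2x2 chunks (2048), which still divides 10240
        System.out.println(compatible(chunk, new int[] {1, 1, 1, 2048, 2048}, shape)); // true

        // SINGLE preset: one shard covering the whole image
        System.out.println(compatible(chunk, shape, shape)); // true

        // Incompatible shard: 1536 is not a multiple of 1024, so sharding would be skipped
        System.out.println(compatible(chunk, new int[] {1, 1, 1, 1536, 1536}, shape)); // false
    }
}
```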
59 changes: 59 additions & 0 deletions src/test/java/com/glencoesoftware/zarr/test/ConversionTest.java
@@ -316,6 +316,65 @@ public void testCodecs() throws Exception {
}
}

/**
* Test different sharding options
*/
@Test
public void testSharding() throws Exception {
input = fake("sizeX", "10240", "sizeY", "10240");
assertBioFormats2Raw();

String[] shardOptions = new String[] {
"SINGLE", "CHUNK", "SUPERCHUNK"
};
int[][] shardSizes = new int[][] {
{1, 1, 1, 10240, 10240},
{1, 1, 1, 1024, 1024},
{1, 1, 1, 2048, 2048}
};

for (int opt=0; opt<shardOptions.length; opt++) {
// first convert v2 produced by bioformats2raw to v3
Path v3Output = tmp.newFolder().toPath().resolve("v3-test");
Convert v3Converter = new Convert();
v3Converter.setInput(output.toString());
v3Converter.setOutput(v3Output.toString());

v3Converter.setSharding(shardOptions[opt]);
v3Converter.convertToV3();

// check the shard size recorded in the v3 array metadata

Store store = new FilesystemStore(v3Output);
Array resolution = Array.open(store.resolve("0", "0"));

int[] shardSize = shardSizes[opt];
Assert.assertArrayEquals(resolution.metadata.chunkShape(), shardSize);

// now convert v3 back to v2
Path roundtripOutput = tmp.newFolder().toPath().resolve("v2-roundtrip-test");
Convert v2Converter = new Convert();
v2Converter.setInput(v3Output.toString());
v2Converter.setOutput(roundtripOutput.toString());
v2Converter.setWriteV2(true);
v2Converter.convertToV2();

Path originalOMEXML = output.resolve("OME").resolve("METADATA.ome.xml");
Path roundtripOMEXML = roundtripOutput.resolve("OME").resolve("METADATA.ome.xml");

// make sure the OME-XML is present and not changed
Assert.assertEquals(Files.readAllLines(originalOMEXML), Files.readAllLines(roundtripOMEXML));

// since the image is small, make sure all pixels are identical in the original and roundtripped arrays, at every resolution
for (int r=0; r<7; r++) {
ZarrArray original = ZarrGroup.open(output.resolve("0")).openArray(String.valueOf(r));
ZarrArray roundtrip = ZarrGroup.open(roundtripOutput.resolve("0")).openArray(String.valueOf(r));

compareZarrArrays(original, roundtrip);
}
}
}

private void compareZarrArrays(ZarrArray original, ZarrArray roundtrip) throws Exception {
Assert.assertArrayEquals(original.getShape(), roundtrip.getShape());

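For reference, the converter can also be driven programmatically, as testSharding does above. A minimal sketch, assuming hypothetical input and output paths:

```java
import com.glencoesoftware.zarr.Convert;

public class ConvertDemo {
    public static void main(String[] args) throws Exception {
        // Usage sketch mirroring testSharding; paths are hypothetical.
        Convert converter = new Convert();
        converter.setInput("/data/example.zarr");     // Zarr v2 input, e.g. bioformats2raw output
        converter.setOutput("/data/example-v3.zarr"); // Zarr v3 output location
        converter.setSharding("SUPERCHUNK");          // one of SINGLE, CHUNK, SUPERCHUNK
        converter.convertToV3();
        // For the reverse direction, setWriteV2(true) and convertToV2() convert v3 back to v2.
    }
}
```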