Skip to content

Commit

Permalink
Fix call to as_mars()
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 22, 2024
1 parent 126fc5c commit e5ee8d3
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 10 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ classifiers = [

dependencies = [
"cdsapi",
"earthkit-data>=0.10.1",
"earthkit-data>=0.10.3",
"earthkit-meteo",
"earthkit-regrid",
"eccodes>=2.37",
Expand Down Expand Up @@ -72,6 +72,7 @@ file = "ai_models.inputs.file:FileInput"
mars = "ai_models.inputs.mars:MarsInput"
cds = "ai_models.inputs.cds:CdsInput"
ecmwf-open-data = "ai_models.inputs.opendata:OpenDataInput"
opendata = "ai_models.inputs.opendata:OpenDataInput"

[project.entry-points."ai_models.output"]
file = "ai_models.outputs:FileOutput"
Expand Down
5 changes: 4 additions & 1 deletion src/ai_models/inputs/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging

import earthkit.regrid as ekr
import tqdm
from earthkit.data.indexing.fieldlist import FieldArray

from .transform import NewDataField
Expand All @@ -22,7 +23,9 @@ def __init__(self, grid, source):

def __call__(self, ds):
result = []
for f in ds:
for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False):
data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid))
result.append(NewDataField(f, data))

LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape)
return FieldArray(result)
15 changes: 9 additions & 6 deletions src/ai_models/inputs/opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class OpenDataInput(RequestBasedInput):
WHERE = "OPENDATA"

RESOLS = {
(0.25, 0.25): ("0p25", (0.25, 0.25), False),
"N320": ("0p25", (0.25, 0.25), True),
"O96": ("0p25", (0.25, 0.25), True),
# (0.1, 0.1): ("0p25", (0.25, 0.25), False),
(0.25, 0.25): ("0p25", (0.25, 0.25), False, False),
"N320": ("0p25", (0.25, 0.25), True, False),
"O96": ("0p25", (0.25, 0.25), True, False),
(0.1, 0.1): ("0p25", (0.25, 0.25), True, True),
}

def __init__(self, owner, **kwargs):
Expand All @@ -56,12 +56,15 @@ def _adjust(self, kwargs):
if isinstance(grid, list):
grid = tuple(grid)

kwargs["resol"], source, interp = self.RESOLS[grid]
kwargs["resol"], source, interp, oversampling = self.RESOLS[grid]
r = dict(**kwargs)
r.update(self.owner.retrieve)

if interp:
logging.debug("Interpolating from %s to %s", source, grid)

logging.info("Interpolating input data from %s to %s.", source, grid)
if oversampling:
logging.warning("This will oversample the input data.")
return Interpolate(grid, source)
else:
return lambda x: x
Expand Down
1 change: 1 addition & 0 deletions src/ai_models/inputs/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class NewDataField(WrappedField):
def __init__(self, field, data):
super().__init__(field)
self._data = data
self.shape = data.shape

def to_numpy(self, flatten=False, dtype=None, index=None):
data = self._data
Expand Down
6 changes: 4 additions & 2 deletions src/ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def collect_archive_requests(self, written):
# does not return always return recently set keys
handle = handle.clone()

self.archiving[path].add(handle.as_mars())
self.archiving[path].add(handle.as_namespace("mars"))

def finalise(self):
self.output.flush()
Expand Down Expand Up @@ -536,6 +536,8 @@ def write_input_fields(
accumulations_shape=None,
ignore=None,
):
LOG.info("Starting date is %s", self.start_datetime)
LOG.info("Writing input fields")
if ignore is None:
ignore = []

Expand All @@ -553,7 +555,7 @@ def write_input_fields(

if accumulations is not None:
if accumulations_template is None:
accumulations_template = fields.sel(param="2t")[0]
accumulations_template = fields.sel(param="msl")[0]

if accumulations_shape is None:
accumulations_shape = accumulations_template.shape
Expand Down
12 changes: 12 additions & 0 deletions src/ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def write(self, data, *args, check=False, **kwargs):
raise ValueError(f"NaN values found in field. args={args} kwargs={kwargs}")
if np.isinf(data).any():
raise ValueError(f"Infinite values found in field. args={args} kwargs={kwargs}")

options = {}
options.update(self.grib_keys)
options.update(kwargs)
LOG.error("Failed to write data to %s %s", args, options)
cmd = []
for k, v in options.items():
if isinstance(v, (int, str, float)):
cmd.append("%s=%s" % (k, v))

LOG.error("grib_set -s%s", ",".join(cmd))

raise

if check:
Expand Down

0 comments on commit e5ee8d3

Please sign in to comment.