Skip to content

Commit

Permalink
download constants
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 14, 2024
1 parent baba519 commit 1351d58
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
44 changes: 38 additions & 6 deletions src/ai_models/inputs/opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import datetime
import itertools
import logging
import os

import earthkit.data as ekd
from earthkit.data.indexing.fieldlist import FieldArray
from multiurl import download

from .base import RequestBasedInput
from .compute import make_z_from_gh
Expand All @@ -26,6 +28,8 @@
"slor",
)

# URL of the GRIB2 file that supplies the single-level parameters which are
# not available in open data. The file is downloaded on first use and kept
# in the ~/.cache/ai-models directory for subsequent runs.
CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants.grib2"


class OpenDataInput(RequestBasedInput):
WHERE = "OPENDATA"
Expand Down Expand Up @@ -102,14 +106,30 @@ def sfc_load_source(self, **kwargs):
if constant_params:
if len(constant_params) == 1:
logging.warning(
f"Single level parameter '{constant_params[0]}' is not available in open data, using constants.grib2 instead"
f"Single level parameter '{constant_params[0]}' is"
" not available in open data, using constants.grib2 instead"
)
else:
logging.warning(
f"Single level parameters {constant_params} are not available in open data, using constants.grib2 instead"
f"Single level parameters {constant_params} are"
" not available in open data, using constants.grib2 instead"
)
constants = []
ds = ekd.from_source("file", "constants.grib2")

cachedir = os.path.expanduser("~/.cache/ai-models")
basename = os.path.basename(CONSTANTS_URL)

if not os.path.exists(cachedir):
os.makedirs(cachedir)

path = os.path.join(cachedir, basename)

if not os.path.exists(path):
logging.info("Downloading %s to %s", CONSTANTS_URL, path)
download(CONSTANTS_URL, path + ".tmp")
os.rename(path + ".tmp", path)

ds = ekd.from_source("file", path)
ds = ds.sel(param=constant_params)

date = int(kwargs["date"])
Expand All @@ -125,7 +145,13 @@ def sfc_load_source(self, **kwargs):

# assert False, (date, time, step)
constants.append(
NewMetadataField(f, valid_datetime=str(valid), date=date, time="%4d" % (time,), step=step)
NewMetadataField(
f,
valid_datetime=str(valid),
date=date,
time="%4d" % (time,),
step=step,
)
)

constants = FieldArray(constants)
Expand All @@ -134,7 +160,13 @@ def sfc_load_source(self, **kwargs):

logging.debug("load source ecmwf-open-data %s", kwargs)

return self.check_sfc(pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants), request)
fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants)

# Fix grib2/eccodes bug

fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields])

return self.check_sfc(fields, request)

def ml_load_source(self, **kwargs):
pproc = self._adjust(kwargs)
Expand All @@ -157,7 +189,7 @@ def check_ml(self, ds, request):
return ds

def _check(self, ds, what, request, *keys):
print("CHECKING", what)

expected = set()
for p in itertools.product(*[request[key] for key in keys]):
expected.add(p)
Expand Down
8 changes: 4 additions & 4 deletions src/ai_models/inputs/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __getattr__(self, name):
def __repr__(self) -> str:
return repr(self._field)

def metadata(self, name, **kwargs):
    """Return the metadata value for ``name``, preferring local overrides.

    Args:
        name: Metadata key to look up.
        **kwargs: Extra options forwarded to the wrapped field's
            ``metadata`` call when no override exists for ``name``.

    Returns:
        The value stored in ``self._metadata`` for ``name`` when present,
        otherwise the result of the wrapped field's ``metadata``.
    """
    overrides = self._metadata
    if name not in overrides:
        # No local override: defer to the wrapped field.
        return self._field.metadata(name, **kwargs)
    return overrides[name]
def metadata(self, *args, **kwargs):
    """Return metadata, giving locally overridden keys priority.

    A call with exactly one positional string key that is present in
    ``self._metadata`` returns the stored override. Every other call
    (no key, several keys, a list/tuple of keys, or an unknown key) is
    forwarded unchanged to the wrapped field's ``metadata``.

    Args:
        *args: Metadata key(s); forwarded as-is when no override applies.
        **kwargs: Extra options forwarded to the wrapped field.

    Returns:
        The override value, or whatever the wrapped field returns.
    """
    if len(args) == 1:
        key = args[0]
        # Check the type before the membership test: an unhashable key
        # (e.g. a list of names) would otherwise raise TypeError instead
        # of being delegated to the wrapped field.
        # NOTE(review): assumes self._metadata is a dict keyed by str - confirm.
        if isinstance(key, str) and key in self._metadata:
            return self._metadata[key]
    # Overrides are not applied to multi-key requests; they go straight
    # to the wrapped field.
    return self._field.metadata(*args, **kwargs)

0 comments on commit 1351d58

Please sign in to comment.