From 1351d586d2e1862645ac042b51803e6c94060bbd Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 14 Sep 2024 15:56:29 +0000 Subject: [PATCH] download constants --- src/ai_models/inputs/opendata.py | 44 ++++++++++++++++++++++++++----- src/ai_models/inputs/transform.py | 8 +++--- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index e0ec5f5..5de4ce3 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -8,9 +8,11 @@ import datetime import itertools import logging +import os import earthkit.data as ekd from earthkit.data.indexing.fieldlist import FieldArray +from multiurl import download from .base import RequestBasedInput from .compute import make_z_from_gh @@ -26,6 +28,8 @@ "slor", ) +CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants.grib2" + class OpenDataInput(RequestBasedInput): WHERE = "OPENDATA" @@ -102,14 +106,30 @@ def sfc_load_source(self, **kwargs): if constant_params: if len(constant_params) == 1: logging.warning( - f"Single level parameter '{constant_params[0]}' is not available in open data, using constants.grib2 instead" + f"Single level parameter '{constant_params[0]}' is" + " not available in open data, using constants.grib2 instead" ) else: logging.warning( - f"Single level parameters {constant_params} are not available in open data, using constants.grib2 instead" + f"Single level parameters {constant_params} are" + " not available in open data, using constants.grib2 instead" ) constants = [] - ds = ekd.from_source("file", "constants.grib2") + + cachedir = os.path.expanduser("~/.cache/ai-models") + basename = os.path.basename(CONSTANTS_URL) + + if not os.path.exists(cachedir): + os.makedirs(cachedir) + + path = os.path.join(cachedir, basename) + + if not os.path.exists(path): + logging.info("Downloading %s to %s", CONSTANTS_URL, path) + download(CONSTANTS_URL, path + ".tmp") + os.rename(path + ".tmp", path) + + ds = ekd.from_source("file", path) ds = ds.sel(param=constant_params) date = int(kwargs["date"]) @@ -125,7 +145,13 @@ def sfc_load_source(self, **kwargs): # assert False, (date, time, step) constants.append( - NewMetadataField(f, valid_datetime=str(valid), date=date, time="%4d" % (time,), step=step) + NewMetadataField( + f, + valid_datetime=str(valid), + date=date, + time="%4d" % (time,), + step=step, + ) ) constants = FieldArray(constants) @@ -134,7 +160,13 @@ def sfc_load_source(self, **kwargs): logging.debug("load source ecmwf-open-data %s", kwargs) - return self.check_sfc(pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants), request) + fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants) + + # Fix grib2/eccodes bug + + fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields]) + + return self.check_sfc(fields, request) def ml_load_source(self, **kwargs): pproc = self._adjust(kwargs) @@ -157,7 +189,7 @@ def check_ml(self, ds, request): return ds def _check(self, ds, what, request, *keys): - print("CHECKING", what) + expected = set() for p in itertools.product(*[request[key] for key in keys]): expected.add(p) diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py index a11aa62..29ee983 100644 --- a/src/ai_models/inputs/transform.py +++ b/src/ai_models/inputs/transform.py @@ -41,7 +41,7 @@ def __getattr__(self, name): def __repr__(self) -> str: return repr(self._field) - def metadata(self, name, **kwargs): - if name in self._metadata: - return self._metadata[name] - return self._field.metadata(name, **kwargs) + def metadata(self, *args, **kwargs): + if len(args) == 1 and args[0] in self._metadata: + return self._metadata[args[0]] + return self._field.metadata(*args, **kwargs)