Skip to content

Commit

Permalink
Refactor FileOutput to inherit from FileOutputBase
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Jun 18, 2024
1 parent 58fc061 commit fb1fd11
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions src/ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,31 @@ def finalise(self, *args, **kwargs):
pass


class FileOutput(Output):
class FileOutputBase(Output):
def __init__(self, owner, path, metadata, **kwargs):
self._first = True
metadata.setdefault("stream", "oper")
metadata.setdefault("expver", owner.expver)
metadata.setdefault("class", "ml")

LOG.info("Writing results to %s.", path)
self.path = path
self.owner = owner
self.metadata = metadata

@cached_property
def output(self):

def grib_keys(self):
edition = self.metadata.pop("edition", 2)

self.grib_keys = dict(
_grib_keys = dict(
edition=edition,
generatingProcessIdentifier=self.owner.version,
)
self.grib_keys.update(self.metadata)
_grib_keys.update(self.metadata)

return _grib_keys

@cached_property
def output(self):
return cml.new_grib_output(
self.path,
split_output=True,
Expand Down Expand Up @@ -91,6 +93,20 @@ def write(self, data, *args, check=False, **kwargs):
return handle, path


class FileOutput(FileOutputBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
LOG.info("Writing results to %s", self.path)


class NoneOutput(Output):
def __init__(self, *args, **kwargs):
LOG.info("Results will not be written.")

def write(self, *args, **kwargs):
pass


class HindcastReLabel:
def __init__(self, owner, output, hindcast_reference_year=None, hindcast_reference_date=None, **kwargs):
self.owner = owner
Expand Down Expand Up @@ -151,14 +167,6 @@ def write(self, *args, **kwargs):
return self.output.write(*args, **kwargs)


class NoneOutput(Output):
def __init__(self, *args, **kwargs):
LOG.info("Results will not be written.")

def write(self, *args, **kwargs):
pass


def get_output(name, owner, *args, **kwargs):
result = available_outputs()[name].load()(owner, *args, **kwargs)
if kwargs.get("hindcast_reference_year") is not None or kwargs.get("hindcast_reference_date") is not None:
Expand Down

0 comments on commit fb1fd11

Please sign in to comment.