Skip to content

Commit

Permalink
added automatic file_format determination based on filename
Browse files Browse the repository at this point in the history
  • Loading branch information
wolearyc committed Oct 11, 2024
1 parent 27dde49 commit 4ef7244
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ These generic functions are less flexible than those first mentioned, and theref
model = rn.pmodel.InterpolationModel(...)
model.add_dof_from_files(..., file_format = "outcar")
:meth:`.InterpolationModel.add_dof_from_files` and other methods like it rely on these generic methods, as apparent from the ``file_format`` argument.
:meth:`.InterpolationModel.add_dof_from_files` and other methods like it rely on these generic methods, as apparent from the ``file_format`` argument. When ``file_format=="auto"``, ramannoodle will guess the file format based on the filename. Otherwise, ``file_format`` should be set according to the next section.

.. _Supported formats:

Expand Down
104 changes: 78 additions & 26 deletions ramannoodle/io/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ramannoodle.structure._reference import ReferenceStructure
from ramannoodle.exceptions import UserError, get_torch_missing_error
import ramannoodle.io.vasp as vasp_io
from ramannoodle.io._utils import pathify_as_list

TORCH_PRESENT = True
try:
Expand Down Expand Up @@ -66,15 +67,44 @@
_TRAJECTORY_WRITERS = {"xdatcar": vasp_io.xdatcar.write_trajectory}


def read_phonons(filepath: str | Path, file_format: str) -> Phonons:
def _process_file_format(filepaths: str | Path | list[str] | list[Path]) -> str:
    """Guess file format from filepath(s).

    The implementation is very simple: it looks for relevant strings in the
    lowercased filename. In the case of multiple files, only the first file
    is used to guess the file format.

    Parameters
    ----------
    filepaths

    Returns
    -------
    :
        Lowercase file_format string.

    Raises
    ------
    ValueError
        No filepaths were provided, or the file format could not be guessed
        from the first filename.

    """
    # NOTE(review): assumes pathify_as_list returns a list of pathlib.Path
    # objects (it is used that way below) -- confirm against io._utils.
    paths = pathify_as_list(filepaths)
    if not paths:
        # Fail with the same exception type as an unrecognized filename instead
        # of leaking an IndexError from the lookup below.
        raise ValueError("could not guess file format: no filepaths provided")

    filename = paths[0].name.lower()
    # (substring looked for in the filename, resulting file_format), checked in
    # order. "vasprun" also matches "vasprun.xml", so one entry covers both.
    markers = (
        ("outcar", "outcar"),
        ("vasprun", "vasprun.xml"),
        ("xdatcar", "xdatcar"),
        ("poscar", "poscar"),
    )
    for marker, file_format in markers:
        if marker in filename:
            return file_format
    raise ValueError(f"could not guess file format: {paths[0]}")


def read_phonons(filepath: str | Path, file_format: str = "auto") -> Phonons:
"""Read phonons from a file.
Parameters
----------
filepath
file_format
Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not
case sensitive.
Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported
formats`). Not case sensitive.
Returns
-------
Expand All @@ -87,22 +117,24 @@ def read_phonons(filepath: str | Path, file_format: str) -> Phonons:
InvalidFileException
Invalid file.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
return _PHONON_READERS[file_format.lower()](filepath)
except KeyError as exc:
raise ValueError(f"unsupported format: {file_format}") from exc


def read_trajectory(filepath: str | Path, file_format: str) -> Trajectory:
def read_trajectory(filepath: str | Path, file_format: str = "auto") -> Trajectory:
"""Read molecular dynamics trajectory from a file.
Parameters
----------
filepath
file_format
Supports ``"outcar"``, ``"vasprun.xml"``, (see :ref:`Supported formats`). Not
case sensitive. Use :func:`.vasp.xdatcar.read_trajectory` to read a trajectory
from an XDATCAR.
Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported
formats`). Not case sensitive. Use :func:`.vasp.xdatcar.read_trajectory` to
read a trajectory from an XDATCAR.
Returns
-------
Expand All @@ -115,6 +147,8 @@ def read_trajectory(filepath: str | Path, file_format: str) -> Trajectory:
InvalidFileException
Invalid file.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
return _TRAJECTORY_READERS[file_format.lower()](filepath)
except KeyError as exc:
Expand All @@ -127,16 +161,16 @@ def read_trajectory(filepath: str | Path, file_format: str) -> Trajectory:

def read_positions_and_polarizability(
filepath: str | Path,
file_format: str,
file_format: str = "auto",
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
"""Read fractional positions and polarizability from a file.
Parameters
----------
filepath
file_format
Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not
case sensitive.
Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported
formats`). Not case sensitive.
Returns
-------
Expand All @@ -152,6 +186,8 @@ def read_positions_and_polarizability(
InvalidFileException
Invalid file.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
return _POSITION_AND_POLARIZABILITY_READERS[file_format.lower()](filepath)
except KeyError as exc:
Expand All @@ -160,16 +196,16 @@ def read_positions_and_polarizability(

def read_structure_and_polarizability(
filepath: str | Path,
file_format: str,
file_format: str = "auto",
) -> tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]]:
"""Read lattice, atomic numbers, fractional positions, polarizability from a file.
Parameters
----------
filepath
file_format
Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not
case sensitive.
Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported
formats`). Not case sensitive.
Returns
-------
Expand All @@ -187,6 +223,8 @@ def read_structure_and_polarizability(
InvalidFileException
Invalid file.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
return _STRUCTURE_AND_POLARIZABILITY_READERS[file_format.lower()](filepath)
except KeyError as exc:
Expand All @@ -195,16 +233,16 @@ def read_structure_and_polarizability(

def read_polarizability_dataset(
filepaths: str | Path | list[str] | list[Path],
file_format: str,
file_format: str = "auto",
) -> "PolarizabilityDataset":
"""Read polarizability dataset from files.
Parameters
----------
filepaths
file_format
Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not
case sensitive.
Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported
formats`). Not case sensitive.
Returns
-------
Expand All @@ -221,6 +259,8 @@ def read_polarizability_dataset(
"""
if not TORCH_PRESENT:
raise get_torch_missing_error()
if file_format.lower() == "auto":
file_format = _process_file_format(filepaths)
try:
return _POLARIZABILITY_DATASET_READERS[file_format.lower()](filepaths)
except KeyError as exc:
Expand All @@ -229,16 +269,16 @@ def read_polarizability_dataset(

def read_positions(
filepath: str | Path,
file_format: str,
file_format: str = "auto",
) -> NDArray[np.float64]:
"""Read fractional positions from a file.
Parameters
----------
filepath
file_format
Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"`` (see
:ref:`Supported formats`). Not case sensitive.
Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"``, and
``"auto"`` (see :ref:`Supported formats`). Not case sensitive.
Returns
-------
Expand All @@ -253,21 +293,25 @@ def read_positions(
Invalid file.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
return _POSITION_READERS[file_format.lower()](filepath)
except KeyError as exc:
raise ValueError(f"unsupported format: {file_format}") from exc


def read_ref_structure(filepath: str | Path, file_format: str) -> ReferenceStructure:
def read_ref_structure(
filepath: str | Path, file_format: str = "auto"
) -> ReferenceStructure:
"""Read reference structure from a file.
Parameters
----------
filepath
file_format
Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"`` (see
:ref:`Supported formats`).
Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"``, and
``"auto"`` (see :ref:`Supported formats`). Not case sensitive.
Returns
-------
Expand All @@ -282,6 +326,8 @@ def read_ref_structure(filepath: str | Path, file_format: str) -> ReferenceStruc
SymmetryException
Structural symmetry determination failed.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
return _REFERENCE_STRUCTURE_READERS[file_format.lower()](filepath)
except KeyError as exc:
Expand All @@ -293,7 +339,7 @@ def write_structure( # pylint: disable=too-many-arguments,too-many-positional-a
atomic_numbers: list[int],
positions: NDArray[np.float64],
filepath: str | Path,
file_format: str,
file_format: str = "auto",
overwrite: bool = False,
) -> None:
"""Write structure to file.
Expand All @@ -308,7 +354,8 @@ def write_structure( # pylint: disable=too-many-arguments,too-many-positional-a
(fractional) Array with shape (N,3).
filepath
file_format
Supports ``"poscar"`` (see :ref:`Supported formats`). Not case sensitive.
Supports ``"poscar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case
sensitive.
overwrite
Overwrite the file if it exists.
label
Expand All @@ -319,6 +366,8 @@ def write_structure( # pylint: disable=too-many-arguments,too-many-positional-a
FileExistsError
File exists and ``overwrite == False``.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
_STRUCTURE_WRITERS[file_format.lower()](
lattice=lattice,
Expand All @@ -337,7 +386,7 @@ def write_trajectory(
atomic_numbers: list[int],
positions_ts: NDArray[np.float64],
filepath: str | Path,
file_format: str,
file_format: str = "auto",
overwrite: bool = False,
) -> None:
"""Write trajectory to file.
Expand All @@ -353,7 +402,8 @@ def write_trajectory(
configurations.
filepath
file_format
Supports ``"xdatcar"`` (see :ref:`Supported formats`). Not case sensitive.
Supports ``"xdatcar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case
sensitive.
overwrite
Overwrite the file if it exists.
Expand All @@ -362,6 +412,8 @@ def write_trajectory(
FileExistsError
File exists and ``overwrite == False``.
"""
if file_format.lower() == "auto":
file_format = _process_file_format(filepath)
try:
_TRAJECTORY_WRITERS[file_format.lower()](
lattice=lattice,
Expand Down
8 changes: 4 additions & 4 deletions ramannoodle/pmodel/_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def add_art(
def add_art_from_files(
self,
filepaths: str | Path | list[str] | list[Path],
file_format: str,
file_format: str = "auto",
) -> None:
"""Add an atomic Raman tensor (ART) from file(s).
Expand All @@ -211,9 +211,9 @@ def add_art_from_files(
----------
filepaths
file_format
Supports ``"outcar"`` and ``"vasprun.xml"``. If dummy model, supports
``"poscar"`` and ``"xdatcar"`` as well (see :ref:`Supported formats`). Not
case sensitive.
Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"``. If dummy model,
supports ``"poscar"`` and ``"xdatcar"`` as well (see :ref:`Supported
formats`). Not case sensitive.
Raises
------
Expand Down
4 changes: 3 additions & 1 deletion ramannoodle/pmodel/_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,9 @@ def add_dof_from_pymatgen(
)

def _read_dof(
self, filepaths: str | Path | list[str] | list[Path], file_format: str
self,
filepaths: str | Path | list[str] | list[Path],
file_format: str,
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
"""Read displacements, amplitudes, and polarizabilities from file(s).
Expand Down
10 changes: 6 additions & 4 deletions ramannoodle/structure/_displace.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def write_displaced_structures(
cart_displacement: NDArray[np.float64],
amplitudes: NDArray[np.float64],
filepaths: str | Path | list[str] | list[Path],
file_format: str,
file_format: str = "auto",
overwrite: bool = False,
) -> None:
"""Write displaced structures to files.
Expand All @@ -96,7 +96,8 @@ def write_displaced_structures(
(Å) Array with shape (M,).
filepaths
file_format
Supports ``"poscar"`` (see :ref:`Supported formats`). Not case sensitive.
Supports ``"poscar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case
sensitive.
overwrite
If ``True``, overwrite the file if it exists.
"""
Expand Down Expand Up @@ -163,7 +164,7 @@ def write_ast_displaced_structures(
cart_direction: NDArray[np.float64],
amplitudes: NDArray[np.float64],
filepaths: str | Path | list[str] | list[Path],
file_format: str,
file_format: str = "auto",
overwrite: bool = False,
) -> None:
"""Write displaced structures with a single atom displaced along a direction.
Expand All @@ -180,7 +181,8 @@ def write_ast_displaced_structures(
(Å) Array with shape (M,).
filepaths
file_format
Supports ``"poscar"`` (see :ref:`Supported formats`). Not case sensitive.
Supports ``"poscar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case
sensitive.
overwrite
Overwrite the file if it exists.
"""
Expand Down
2 changes: 1 addition & 1 deletion test/tests/test_outcar.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,6 @@ def test_read_trajectory_from_outcar(
path_fixture: Path, trajectory_length: int, last_position: NDArray[np.float64]
) -> None:
"""Test read_trajectory for outcar (normal)."""
trajectory = generic_io.read_trajectory(path_fixture, file_format="outcar")
trajectory = generic_io.read_trajectory(path_fixture)
assert len(trajectory) == trajectory_length
assert np.allclose(last_position, trajectory[-1][-1])
3 changes: 2 additions & 1 deletion test/tests/torch/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
],
)
def test_load_polarizability_dataset(
filepaths: str | list[str], file_format: str
filepaths: str | list[str],
file_format: str,
) -> None:
"""Test of generic load_polarizability_dataset (normal)."""
dataset = generic_io.read_polarizability_dataset(filepaths, file_format)
Expand Down

0 comments on commit 4ef7244

Please sign in to comment.