diff --git a/docs/source/io.rst b/docs/source/io.rst index 100ccb1..baea14b 100644 --- a/docs/source/io.rst +++ b/docs/source/io.rst @@ -33,7 +33,7 @@ These generic functions are less flexible than those first mentioned, and theref model = rn.pmodel.InterpolationModel(...) model.add_dof_from_files(..., file_format = "outcar") -:meth:`.InterpolationModel.add_dof_from_files` and other methods like it rely on these generic methods, as apparent from the ``file_format`` argument. +:meth:`.InterpolationModel.add_dof_from_files` and other methods like it rely on these generic methods, as apparent from the ``file_format`` argument. When ``file_format=="auto"``, ramannoodle will guess the file format based on the filename. Otherwise, ``file_format`` should be set according to the next section. .. _Supported formats: diff --git a/ramannoodle/io/generic.py b/ramannoodle/io/generic.py index 3c81f8b..a90dcac 100644 --- a/ramannoodle/io/generic.py +++ b/ramannoodle/io/generic.py @@ -18,6 +18,7 @@ from ramannoodle.structure._reference import ReferenceStructure from ramannoodle.exceptions import UserError, get_torch_missing_error import ramannoodle.io.vasp as vasp_io +from ramannoodle.io._utils import pathify_as_list TORCH_PRESENT = True try: @@ -66,15 +67,44 @@ _TRAJECTORY_WRITERS = {"xdatcar": vasp_io.xdatcar.write_trajectory} -def read_phonons(filepath: str | Path, file_format: str) -> Phonons: +def _process_file_format(filepaths: str | Path | list[str] | list[Path]) -> str: + """Guess file format from filepath(s). + + The implementation is very simple. It simply looks for relevant + strings in the filename. In the case of multiple files, the first + file is used to guess the file format. + + Parameters + ---------- + filepaths + + Returns + ------- + : + Lowercase file_format string. 
+ """ + filepaths = pathify_as_list(filepaths) + filename = filepaths[0].name.lower() + if "outcar" in filename: + return "outcar" + if "vasprun.xml" in filename or "vasprun" in filename: + return "vasprun.xml" + if "xdatcar" in filename: + return "xdatcar" + if "poscar" in filename: + return "poscar" + raise ValueError(f"could not guess file format: {filepaths[0]}") + + +def read_phonons(filepath: str | Path, file_format: str = "auto") -> Phonons: """Read phonons from a file. Parameters ---------- filepath file_format - Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not - case sensitive. + Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported + formats`). Not case sensitive. Returns ------- @@ -87,22 +117,24 @@ def read_phonons(filepath: str | Path, file_format: str) -> Phonons: InvalidFileException Invalid file. """ + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: return _PHONON_READERS[file_format.lower()](filepath) except KeyError as exc: raise ValueError(f"unsupported format: {file_format}") from exc -def read_trajectory(filepath: str | Path, file_format: str) -> Trajectory: +def read_trajectory(filepath: str | Path, file_format: str = "auto") -> Trajectory: """Read molecular dynamics trajectory from a file. Parameters ---------- filepath file_format - Supports ``"outcar"``, ``"vasprun.xml"``, (see :ref:`Supported formats`). Not - case sensitive. Use :func:`.vasp.xdatcar.read_trajectory` to read a trajectory - from an XDATCAR. + Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported + formats`). Not case sensitive. Use :func:`.vasp.xdatcar.read_trajectory` to + read a trajectory from an XDATCAR. Returns ------- @@ -115,6 +147,8 @@ def read_trajectory(filepath: str | Path, file_format: str) -> Trajectory: InvalidFileException Invalid file. 
""" + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: return _TRAJECTORY_READERS[file_format.lower()](filepath) except KeyError as exc: @@ -127,7 +161,7 @@ def read_trajectory(filepath: str | Path, file_format: str) -> Trajectory: def read_positions_and_polarizability( filepath: str | Path, - file_format: str, + file_format: str = "auto", ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """Read fractional positions and polarizability from a file. @@ -135,8 +169,8 @@ def read_positions_and_polarizability( ---------- filepath file_format - Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not - case sensitive. + Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported + formats`). Not case sensitive. Returns ------- @@ -152,6 +186,8 @@ def read_positions_and_polarizability( InvalidFileException Invalid file. """ + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: return _POSITION_AND_POLARIZABILITY_READERS[file_format.lower()](filepath) except KeyError as exc: @@ -160,7 +196,7 @@ def read_positions_and_polarizability( def read_structure_and_polarizability( filepath: str | Path, - file_format: str, + file_format: str = "auto", ) -> tuple[NDArray[np.float64], list[int], NDArray[np.float64], NDArray[np.float64]]: """Read lattice, atomic numbers, fractional positions, polarizability from a file. @@ -168,8 +204,8 @@ def read_structure_and_polarizability( ---------- filepath file_format - Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not - case sensitive. + Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported + formats`). Not case sensitive. Returns ------- @@ -187,6 +223,8 @@ def read_structure_and_polarizability( InvalidFileException Invalid file. 
""" + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: return _STRUCTURE_AND_POLARIZABILITY_READERS[file_format.lower()](filepath) except KeyError as exc: @@ -195,7 +233,7 @@ def read_structure_and_polarizability( def read_polarizability_dataset( filepaths: str | Path | list[str] | list[Path], - file_format: str, + file_format: str = "auto", ) -> "PolarizabilityDataset": """Read polarizability dataset from files. @@ -203,8 +241,8 @@ def read_polarizability_dataset( ---------- filepaths file_format - Supports ``"outcar"``, ``"vasprun.xml"`` (see :ref:`Supported formats`). Not - case sensitive. + Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"`` (see :ref:`Supported + formats`). Not case sensitive. Returns ------- @@ -221,6 +259,8 @@ def read_polarizability_dataset( """ if not TORCH_PRESENT: raise get_torch_missing_error() + if file_format.lower() == "auto": + file_format = _process_file_format(filepaths) try: return _POLARIZABILITY_DATASET_READERS[file_format.lower()](filepaths) except KeyError as exc: @@ -229,7 +269,7 @@ def read_polarizability_dataset( def read_positions( filepath: str | Path, - file_format: str, + file_format: str = "auto", ) -> NDArray[np.float64]: """Read fractional positions from a file. @@ -237,8 +277,8 @@ def read_positions( ---------- filepath file_format - Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"`` (see - :ref:`Supported formats`). Not case sensitive. + Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"``, and + ``"auto"` (see :ref:`Supported formats`). Not case sensitive. Returns ------- @@ -253,21 +293,25 @@ def read_positions( Invalid file. 
""" + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: return _POSITION_READERS[file_format.lower()](filepath) except KeyError as exc: raise ValueError(f"unsupported format: {file_format}") from exc -def read_ref_structure(filepath: str | Path, file_format: str) -> ReferenceStructure: +def read_ref_structure( + filepath: str | Path, file_format: str = "auto" +) -> ReferenceStructure: """Read reference structure from a file. Parameters ---------- filepath file_format - Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"`` (see - :ref:`Supported formats`). + Supports ``"outcar"``, ``"poscar"``, ``"xdatcar"``, ``"vasprun.xml"``, and + ``"auto"`` (see :ref:`Supported formats`). Returns ------- @@ -282,6 +326,8 @@ def read_ref_structure(filepath: str | Path, file_format: str) -> ReferenceStruc SymmetryException Structural symmetry determination failed. """ + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: return _REFERENCE_STRUCTURE_READERS[file_format.lower()](filepath) except KeyError as exc: @@ -293,7 +339,7 @@ def write_structure( # pylint: disable=too-many-arguments,too-many-positional-a atomic_numbers: list[int], positions: NDArray[np.float64], filepath: str | Path, - file_format: str, + file_format: str = "auto", overwrite: bool = False, ) -> None: """Write structure to file. @@ -308,7 +354,8 @@ def write_structure( # pylint: disable=too-many-arguments,too-many-positional-a (fractional) Array with shape (N,3). filepath file_format - Supports ``"poscar"`` (see :ref:`Supported formats`). Not case sensitive. + Supports ``"poscar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case + sensitive. overwrite Overwrite the file if it exists. label @@ -319,6 +366,8 @@ def write_structure( # pylint: disable=too-many-arguments,too-many-positional-a FileExistsError File exists and ``overwrite == False``. 
""" + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: _STRUCTURE_WRITERS[file_format.lower()]( lattice=lattice, @@ -337,7 +386,7 @@ def write_trajectory( atomic_numbers: list[int], positions_ts: NDArray[np.float64], filepath: str | Path, - file_format: str, + file_format: str = "auto", overwrite: bool = False, ) -> None: """Write trajectory to file. @@ -353,7 +402,8 @@ def write_trajectory( configurations. filepath file_format - Supports ``"xdatcar"`` (see :ref:`Supported formats`). Not case sensitive. + Supports ``"xdatcar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case + sensitive. overwrite Overwrite the file if it exists. @@ -362,6 +412,8 @@ def write_trajectory( FileExistsError File exists and ``overwrite == False``. """ + if file_format.lower() == "auto": + file_format = _process_file_format(filepath) try: _TRAJECTORY_WRITERS[file_format.lower()]( lattice=lattice, diff --git a/ramannoodle/pmodel/_art.py b/ramannoodle/pmodel/_art.py index 1162c6b..b3f94b8 100644 --- a/ramannoodle/pmodel/_art.py +++ b/ramannoodle/pmodel/_art.py @@ -199,7 +199,7 @@ def add_art( def add_art_from_files( self, filepaths: str | Path | list[str] | list[Path], - file_format: str, + file_format: str = "auto", ) -> None: """Add an atomic Raman tensor (ART) from file(s). @@ -211,9 +211,9 @@ def add_art_from_files( ---------- filepaths file_format - Supports ``"outcar"`` and ``"vasprun.xml"``. If dummy model, supports - ``"poscar"`` and ``"xdatcar"`` as well (see :ref:`Supported formats`). Not - case sensitive. + Supports ``"outcar"``, ``"vasprun.xml"``, and ``"auto"``. If dummy model, + supports ``"poscar"`` and ``"xdatcar"`` as well (see :ref:`Supported + formats`). Not case sensitive. 
Raises ------ diff --git a/ramannoodle/pmodel/_interpolation.py b/ramannoodle/pmodel/_interpolation.py index 9d1a434..3a405e5 100644 --- a/ramannoodle/pmodel/_interpolation.py +++ b/ramannoodle/pmodel/_interpolation.py @@ -624,7 +624,9 @@ def add_dof_from_pymatgen( ) def _read_dof( - self, filepaths: str | Path | list[str] | list[Path], file_format: str + self, + filepaths: str | Path | list[str] | list[Path], + file_format: str, ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]: """Read displacements, amplitudes, and polarizabilities from file(s). diff --git a/ramannoodle/structure/_displace.py b/ramannoodle/structure/_displace.py index ff446a2..7ee13b0 100644 --- a/ramannoodle/structure/_displace.py +++ b/ramannoodle/structure/_displace.py @@ -80,7 +80,7 @@ def write_displaced_structures( cart_displacement: NDArray[np.float64], amplitudes: NDArray[np.float64], filepaths: str | Path | list[str] | list[Path], - file_format: str, + file_format: str = "auto", overwrite: bool = False, ) -> None: """Write displaced structures to files. @@ -96,7 +96,8 @@ def write_displaced_structures( (Å) Array with shape (M,). filepaths file_format - Supports ``"poscar"`` (see :ref:`Supported formats`). Not case sensitive. + Supports ``"poscar"`` and ``"auto"`` (see :ref:`Supported formats`). Not case + sensitive. overwrite If ``True``, overwrite the file if it exists. """ @@ -163,7 +164,7 @@ def write_ast_displaced_structures( cart_direction: NDArray[np.float64], amplitudes: NDArray[np.float64], filepaths: str | Path | list[str] | list[Path], - file_format: str, + file_format: str = "auto", overwrite: bool = False, ) -> None: """Write displaced structures with a single atom displaced along a direction. @@ -180,7 +181,8 @@ def write_ast_displaced_structures( (Å) Array with shape (M,). filepaths file_format - Supports ``"poscar"`` (see :ref:`Supported formats`). Not case sensitive. + Supports ``"poscar"`` and ``"auto"`` (see :ref:`Supported formats`). 
Not case + sensitive. overwrite Overwrite the file if it exists. """ diff --git a/test/tests/test_outcar.py b/test/tests/test_outcar.py index ef1be0c..99c0e9c 100644 --- a/test/tests/test_outcar.py +++ b/test/tests/test_outcar.py @@ -89,6 +89,6 @@ def test_read_trajectory_from_outcar( path_fixture: Path, trajectory_length: int, last_position: NDArray[np.float64] ) -> None: """Test read_trajectory for outcar (normal).""" - trajectory = generic_io.read_trajectory(path_fixture, file_format="outcar") + trajectory = generic_io.read_trajectory(path_fixture) assert len(trajectory) == trajectory_length assert np.allclose(last_position, trajectory[-1][-1]) diff --git a/test/tests/torch/test_dataset.py b/test/tests/torch/test_dataset.py index 42961a8..18f67e2 100644 --- a/test/tests/torch/test_dataset.py +++ b/test/tests/torch/test_dataset.py @@ -27,7 +27,8 @@ ], ) def test_load_polarizability_dataset( - filepaths: str | list[str], file_format: str + filepaths: str | list[str], + file_format: str, ) -> None: """Test of generic load_polarizability_dataset (normal).""" dataset = generic_io.read_polarizability_dataset(filepaths, file_format)