Source code for scitacean.transfer.sftp

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025 SciCat Project (https://github.com/SciCatProject/scitacean)
"""SFTP file transfer."""

import os
import stat
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path

from paramiko import SFTPAttributes, SFTPClient, SSHClient

from ..dataset import Dataset
from ..error import FileNotAccessibleError, FileUploadError
from ..file import File
from ..filesystem import RemotePath
from ..logging import get_logger
from ..util.credentials import SecretStr, StrStorage
from ._util import source_folder_for


class SFTPDownloadConnection:
    """Connection for downloading files with SFTP.

    Should be created using
    :meth:`scitacean.transfer.sftp.SFTPFileTransfer.connect_for_download`.
    """

    def __init__(self, *, sftp_client: SFTPClient, host: str) -> None:
        self._sftp_client = sftp_client
        self._host = host

    def download_files(self, *, remote: list[RemotePath], local: list[Path]) -> None:
        """Download files from the given remote paths.

        Parameters
        ----------
        remote:
            Paths of the files on the SFTP host.
        local:
            Local destination paths, one per entry in ``remote``.

        Raises
        ------
        ValueError
            If ``remote`` and ``local`` have different lengths.
            (Raised by ``zip(..., strict=True)`` once the shorter list is exhausted,
            i.e., after the preceding files have been downloaded.)
        FileNotAccessibleError
            If one of the remote files does not exist on the host.
        """
        for remote_path, local_path in zip(remote, local, strict=True):
            self.download_file(remote=remote_path, local=local_path)

    def download_file(self, *, remote: RemotePath, local: Path) -> None:
        """Download a single file from the given remote path.

        Parameters
        ----------
        remote:
            Path of the file on the SFTP host.
        local:
            Local destination path.

        Raises
        ------
        FileNotAccessibleError
            If ``remote`` does not exist on the host.
        """
        get_logger().info(
            "Downloading file %s from host %s to %s",
            remote,
            self._host,
            local,
        )
        # Check the remote file before calling ``get``: ``SFTPClient.get`` opens
        # (creates / truncates) the local file before it opens the remote one.
        # Checking first avoids leaving an empty local file behind and keeps local
        # errors (e.g., a missing local directory, also a FileNotFoundError) from
        # being misreported as a missing remote file.
        try:
            self._sftp_client.stat(remote.posix)
        except FileNotFoundError:
            raise FileNotAccessibleError(
                f"File {remote} not found on SFTP host {self._host}", remote_path=remote
            ) from None
        self._sftp_client.get(remotepath=remote.posix, localpath=os.fspath(local))


class SFTPUploadConnection:
    """Connection for uploading files with SFTP.

    Should be created using
    :meth:`scitacean.transfer.sftp.SFTPFileTransfer.connect_for_upload`.
    """

    def __init__(
        self, *, sftp_client: SFTPClient, source_folder: RemotePath, host: str
    ) -> None:
        self._sftp_client = sftp_client
        self._source_folder = source_folder
        self._host = host

    @property
    def source_folder(self) -> RemotePath:
        """The source folder this connection uploads to."""
        return self._source_folder

    def remote_path(self, filename: str | RemotePath) -> RemotePath:
        """Return the complete remote path for a given path.

        The result is ``filename`` joined onto :attr:`source_folder`.
        """
        return self.source_folder / filename

    def _make_source_folder(self) -> None:
        """Create the source folder and any missing parent folders on the host.

        Raises
        ------
        FileUploadError
            If the folder cannot be created.
        """
        try:
            _mkdir_remote(self._sftp_client, self.source_folder)
        except OSError as exc:
            # Format the exception itself; ``exc.args`` would render as a raw tuple.
            raise FileUploadError(
                f"Failed to create source folder {self.source_folder}: {exc}"
            ) from None

    def upload_files(self, *files: File) -> list[File]:
        """Upload files to the remote folder.

        The source folder is created first if needed.
        If uploading any file fails, the files that were uploaded before the
        failure are removed again (see :meth:`revert_upload`) and the exception
        is re-raised.

        Returns
        -------
        :
            The files as returned by ``File.uploaded`` for each uploaded file.
        """
        self._make_source_folder()
        uploaded: list[File] = []
        try:
            # ``list.extend`` appends the generator's items one at a time, so on
            # failure ``uploaded`` holds exactly the files that finished uploading.
            uploaded.extend(self._upload_file(file) for file in files)
        except Exception:
            self.revert_upload(*uploaded)
            raise
        return uploaded

    def _upload_file(self, file: File) -> File:
        """Upload a single file and return its uploaded representation.

        Raises
        ------
        ValueError
            If ``file`` has no local path.
        """
        if file.local_path is None:
            raise ValueError(
                f"Cannot upload file to {file.remote_path}, the file has no local path"
            )
        remote_path = self.remote_path(file.remote_path)
        get_logger().info(
            "Uploading file %s to %s on host %s",
            file.local_path,
            remote_path,
            self._host,
        )
        st = self._sftp_client.put(
            remotepath=remote_path.posix, localpath=os.fspath(file.local_path)
        )
        return file.uploaded(
            remote_gid=str(st.st_gid),
            remote_uid=str(st.st_uid),
            remote_creation_time=datetime.now(tz=timezone.utc),
            remote_perm=str(st.st_mode),
            remote_size=st.st_size,
        )

    def revert_upload(self, *files: File) -> None:
        """Remove uploaded files from the remote folder.

        This is best-effort: errors are logged as warnings instead of raised so
        that reverting after a failed upload does not replace the original error.
        If the source folder is empty afterwards, it is removed as well.
        """
        for file in files:
            self._revert_upload_single(remote=file.remote_path, local=file.local_path)

        try:
            folder_is_empty = _remote_folder_is_empty(
                self._sftp_client, self.source_folder
            )
        except OSError as exc:
            get_logger().warning(
                "Failed to list remote directory %s on host %s:\n%s",
                self.source_folder,
                self._host,
                exc,
            )
            return
        if not folder_is_empty:
            return

        try:
            get_logger().info(
                "Removing empty remote directory %s on host %s",
                self.source_folder,
                self._host,
            )
            self._sftp_client.rmdir(self.source_folder.posix)
        except OSError as exc:
            get_logger().warning(
                "Failed to remove empty remote directory %s on host %s:\n%s",
                self.source_folder,
                self._host,
                exc,
            )

    def _revert_upload_single(self, *, remote: RemotePath, local: Path | None) -> None:
        """Remove one uploaded file; failures are logged as warnings, not raised."""
        remote_path = self.remote_path(remote)
        get_logger().info(
            "Reverting upload of file %s to %s on host %s",
            local,
            remote_path,
            self._host,
        )

        try:
            self._sftp_client.remove(remote_path.posix)
        except OSError as exc:
            get_logger().warning("Error reverting file %s:\n%s", remote_path, exc)


class SFTPFileTransfer:
    """Upload / download files using SFTP.

    Configuration & Authentication
    ------------------------------
    The file transfer connects to the server at the address given
    as the ``host`` constructor argument.
    This may be

    - a fully qualified host name such as ``"some.fileserver.edu"``,
    - or an IP address like ``"127.0.0.1"``.

    The file transfer relies on :class:`paramiko.client.SSHClient`
    for authentication, and the connection arguments are passed along to
    :meth:`paramiko.client.SSHClient.connect`.
    See its documentation for details.
    ``SFTPFileTransfer`` can use an SSH agent if one is configured or use
    explicitly provided username and password or a key file.
    If none of these options work, you can define a custom ``connect``
    function which creates a :class:`paramiko.sftp_client.SFTPClient`.
    See the examples below.

    Upload folder
    -------------
    The file transfer can take an optional ``source_folder`` as a constructor argument.
    If it is given, ``SFTPFileTransfer`` uploads all files to it and ignores the
    source folder set in the dataset.
    If it is not given, ``SFTPFileTransfer`` uses the dataset's source folder.

    The source folder argument to ``SFTPFileTransfer`` may be a Python format string.
    In that case, all format fields are replaced by the corresponding fields
    of the dataset.
    All non-ASCII characters and most special ASCII characters are replaced.
    This should avoid broken paths from essentially random contents in datasets.

    Relative paths are resolved by the SFTP server (usually relative to the
    login directory), so prefer absolute paths for ``source_folder``.

    Examples
    --------
    Given

    .. code-block:: python

        dset = Dataset(
            type="raw",
            name="my-dataset",
            source_folder="/dataset/source",
        )

    This uploads to ``/dataset/source``:

    .. code-block:: python

        file_transfer = SFTPFileTransfer(host="fileserver")

    This uploads to ``/transfer/folder``:

    .. code-block:: python

        file_transfer = SFTPFileTransfer(
            host="fileserver", source_folder="/transfer/folder"
        )

    This uploads to ``/transfer/my-dataset``:
    (Note that ``{name}`` is replaced by ``dset.name``.)

    .. code-block:: python

        file_transfer = SFTPFileTransfer(
            host="fileserver", source_folder="/transfer/{name}"
        )

    A useful approach is to include a unique ID in the source folder, for example,
    ``"/some/base/folder/{uid}"``, to avoid clashes between different datasets.
    Scitacean will fill in the ``"{uid}"`` placeholder with a new UUID4.

    The connection and authentication method can be customized
    using the ``connect`` argument.
    For example, to use a specific username + SSH key file, use the following:

    .. code-block:: python

        def connect(host, port):
            from paramiko import SSHClient

            client = SSHClient()
            client.load_system_host_keys()
            client.connect(
                hostname=host,
                port=port,
                username="<username>",
                key_filename="<key-file-name>",
            )
            return client.open_sftp()

        file_transfer = SFTPFileTransfer(host="fileserver", connect=connect)

    The :class:`paramiko.client.SSHClient` can be configured as needed in this function.
    """

    def __init__(
        self,
        *,
        host: str,
        port: int = 22,
        username: str | None = None,
        password: str | StrStorage | None = None,
        key_filename: str | None = None,
        source_folder: str | RemotePath | None = None,
        connect: Callable[[str, int | None], SFTPClient] | None = None,
    ) -> None:
        """Construct a new SFTP file transfer.

        Parameters
        ----------
        host:
            URL or name of the server to connect to.
        port:
            Port of the server.
        username:
            Username for the server.
        password:
            Password for the user.
            Or passphrase for the private key, if ``key_filename`` is provided.
            Plain strings are wrapped in a ``SecretStr``.
        key_filename:
            Path to a private key file for authentication.
        source_folder:
            Upload files to this folder if set.
            Otherwise, upload to the dataset's source_folder.
            When downloading, it is only used to locate the
            ``representative_file_path`` in :meth:`connect_for_download`.
        connect:
            If this argument is set, it will be called to create a client
            for the server instead of the builtin method.
            The function arguments are ``host`` and ``port`` as determined by the
            arguments to ``__init__`` shown above.
        """
        self._host = host
        self._port = port
        self._username = username
        self._password = SecretStr(password) if isinstance(password, str) else password
        self._key_filename = key_filename
        self._source_folder_pattern = (
            RemotePath(source_folder) if source_folder is not None else None
        )
        self._connect = connect

    def source_folder_for(self, dataset: Dataset) -> RemotePath:
        """Return the source folder used for the given dataset."""
        return source_folder_for(dataset, self._source_folder_pattern)

    @contextmanager
    def connect_for_download(
        self, dataset: Dataset, representative_file_path: RemotePath
    ) -> Iterator[SFTPDownloadConnection]:
        """Create a connection for downloads, use as a context manager.

        Parameters
        ----------
        dataset:
            The connection will be used to download files of this dataset.
        representative_file_path:
            A path on the SFTP host to check whether files for this
            dataset can be read.
            It is interpreted relative to ``self.source_folder_for(dataset)``.
            The transfer assumes that, if it is possible to read from this path,
            it is possible to read from the paths of all files to be downloaded.

        Returns
        -------
        :
            An open :class:`SFTPDownloadConnection` object.
        """
        sftp_client = _connect(
            self._host,
            self._port,
            self._username,
            self._password,
            self._key_filename,
            connect=self._connect,
        )
        try:
            # Check if the representative file can be read, an exception means that
            # transfer cannot be used for this file.
            test_path = self.source_folder_for(dataset) / representative_file_path
            _ = sftp_client.stat(test_path.posix)
            yield SFTPDownloadConnection(sftp_client=sftp_client, host=self._host)
        finally:
            sftp_client.close()

    @contextmanager
    def connect_for_upload(
        self, dataset: Dataset, representative_file_path: RemotePath
    ) -> Iterator[SFTPUploadConnection]:
        """Create a connection for uploads, use as a context manager.

        Parameters
        ----------
        dataset:
            The connection will be used to upload files of this dataset.
            Used to determine the target folder.
        representative_file_path:
            This is not used by :class:`SFTPFileTransfer`.
            The transfer assumes that all paths are writable when connecting.
            The actual upload fails if the user lacks sufficient permissions.

        Returns
        -------
        :
            An open :class:`SFTPUploadConnection` object.
        """
        source_folder = self.source_folder_for(dataset)
        sftp_client = _connect(
            self._host,
            self._port,
            self._username,
            self._password,
            self._key_filename,
            connect=self._connect,
        )
        try:
            yield SFTPUploadConnection(
                sftp_client=sftp_client, source_folder=source_folder, host=self._host
            )
        finally:
            sftp_client.close()
def _default_connect(
    host: str,
    port: int | None,
    username: str | None,
    password: StrStorage | None,
    key_filename: str | None,
) -> SFTPClient:
    """Open an SFTP session using :class:`paramiko.client.SSHClient`.

    Only arguments that are not ``None`` are forwarded to ``SSHClient.connect``.
    The SSH client is closed again if connecting or opening the SFTP session fails.
    """
    client = SSHClient()
    client.load_system_host_keys()
    args = {
        "hostname": host,
        "port": port,
        "username": username,
        "password": None if password is None else password.get_str(),
        "key_filename": key_filename,
    }
    args = {name: value for name, value in args.items() if value is not None}
    try:
        client.connect(**args)  # type: ignore[arg-type]
        return client.open_sftp()
    except BaseException:
        # Close the client so a partially established connection is not leaked.
        client.close()
        raise


def _connect(
    host: str,
    port: int,
    username: str | None,
    password: StrStorage | None,
    key_filename: str | None,
    connect: Callable[[str, int | None], SFTPClient] | None,
) -> SFTPClient:
    """Create an SFTP client using ``connect`` or, if it is ``None``, the default.

    Any exception is re-raised as a sanitized copy without traceback or chained
    exceptions (see ``_sanitize_connect_error``).
    """
    try:
        if connect is None:
            return _default_connect(host, port, username, password, key_filename)
        return connect(host, port)
    except Exception as exception:
        # We pass secrets as arguments to functions called in this block, and those
        # can be leaked through exception handlers. So catch all exceptions
        # and strip the backtrace up to this point to hide those secrets.
        sanitized = _sanitize_connect_error(exception, host)
    # Raise outside the ``except`` block so that the original exception (and its
    # frames) is not attached as the implicit ``__context__``.
    raise sanitized from None


def _sanitize_connect_error(exception: Exception, host: str) -> Exception:
    """Return an equivalent exception that does not reference the failing frames.

    If the error message mentions known hosts, a note explaining how to accept
    the server's host key is attached.
    """
    try:
        # Unpack the args; passing the tuple as a single argument would nest it
        # and mangle the message.
        sanitized = type(exception)(*exception.args)
    except Exception:
        # Some exception types cannot be re-created from their ``args`` because
        # their ``__init__`` takes different parameters. Reuse the original
        # object but drop its references to frames and chained exceptions.
        sanitized = exception.with_traceback(None)
        sanitized.__cause__ = None
        sanitized.__context__ = None
    # ``args`` may be empty or contain non-strings (e.g., an errno), so test each
    # argument's string form instead of assuming ``args[0]`` is a message.
    if any("known_host" in str(arg) for arg in exception.args):
        notes = list(getattr(sanitized, "__notes__", []))
        notes.append(
            "You may have to connect to the server using a different method first "
            "and accept the server's host key. E.g., in a terminal, run "
            f"`ssh {host}` (you may need to specify a username and port.)."
        )
        sanitized.__notes__ = notes
    return sanitized


def _remote_folder_is_empty(sftp: SFTPClient, path: RemotePath) -> bool:
    """Return whether the remote directory ``path`` has no entries."""
    return not sftp.listdir(path.posix)


def _mkdir_remote(sftp: SFTPClient, path: RemotePath) -> None:
    """Create ``path`` and all missing parent directories on the remote host.

    Raises
    ------
    FileExistsError
        If ``path`` or one of its parents exists but is not a directory.
    """
    if (parent := path.parent) not in (".", "/"):
        _mkdir_remote(sftp, parent)
    st_stat = _try_remote_stat(sftp, path)
    if st_stat is None:
        sftp.mkdir(path.posix)
    elif not _is_remote_dir(st_stat):
        raise FileExistsError(
            f"Cannot make directory because path points to a file: {path}"
        )


def _try_remote_stat(sftp: SFTPClient, path: RemotePath) -> SFTPAttributes | None:
    """Return the attributes of ``path`` or ``None`` if it does not exist."""
    try:
        return sftp.stat(path.posix)
    except FileNotFoundError:
        return None


def _is_remote_dir(st_stat: SFTPAttributes) -> bool:
    """Return whether the given remote attributes describe a directory."""
    if st_stat.st_mode is None:
        return True  # Assume it is a dir and let downstream code fail if it isn't.
    # Compare the whole file-type field: testing only the S_IFDIR bit (0o040000)
    # would also accept block devices (0o060000) and sockets (0o140000).
    return stat.S_ISDIR(st_stat.st_mode)


__all__ = ["SFTPDownloadConnection", "SFTPFileTransfer", "SFTPUploadConnection"]