Source code for torch.distributed.elastic.timer.file_based_local_timer
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import io
import json
import os
import select
import signal
import sys
import threading
import time
from typing import Callable, Dict, List, Optional, Set, Tuple
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
from torch.distributed.elastic.timer.debug_info_logging import (
log_debug_info_for_expired_timers,
)
from torch.distributed.elastic.utils.logging import get_logger
__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
logger = get_logger(__name__)
def _retry(max_retries: int, sleep_time: float) -> Callable:
"""
A simple retry wrapper.
Args:
max_retries: int, the maximum number of retries.
sleep_time: float, the time to sleep between retries.
"""
def wrapper(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
for i in range(max_retries):
try:
return func(*args, **kwargs)
except Exception as e:
logger.exception("Error running %s. Retrying...", func.__name__)
if i < max_retries - 1:
time.sleep(sleep_time)
else:
raise
return wrapper
return wrapper
class FileTimerRequest(TimerRequest):
"""
Data object representing a countdown timer acquisition and release
that is used between the ``FileTimerClient`` and ``FileTimerServer``.
A negative ``expiration_time`` should be interpreted as a "release"
request.
``signal`` is the signal to reap the worker process from the server
process.
"""
__slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"]
def __init__(
self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0
) -> None:
self.version = 1
self.worker_pid = worker_pid
self.scope_id = scope_id
self.expiration_time = expiration_time
self.signal = signal
def __eq__(self, other) -> bool:
if isinstance(other, FileTimerRequest):
return (
self.version == other.version
and self.worker_pid == other.worker_pid
and self.scope_id == other.scope_id
and self.expiration_time == other.expiration_time
and self.signal == other.signal
)
return False
def to_json(self) -> str:
return json.dumps(
{
"version": self.version,
"pid": self.worker_pid,
"scope_id": self.scope_id,
"expiration_time": self.expiration_time,
"signal": self.signal,
},
)
[docs]class FileTimerClient(TimerClient):
"""
Client side of ``FileTimerServer``. This client is meant to be used
on the same host that the ``FileTimerServer`` is running on and uses
pid to uniquely identify a worker.
This client uses a named_pipe to send timer requests to the
``FileTimerServer``. This client is a producer while the
``FileTimerServer`` is a consumer. Multiple clients can work with
the same ``FileTimerServer``.
Args:
file_path: str, the path of a FIFO special file. ``FileTimerServer``
must have created it by calling os.mkfifo().
signal: signal, the signal to use to kill the process. Using a
negative or zero signal will not kill the process.
"""
def __init__(
self,
file_path: str,
signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined]
) -> None:
super().__init__()
self._file_path = file_path
self.signal = signal
@_retry(max_retries=10, sleep_time=0.1)
def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
# The server may have crashed or may haven't started yet.
# In such case, calling open() in blocking model blocks the client.
# To avoid such issue, open it in non-blocking mode, and an OSError will
# be raised if the server is not there.
fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK)
return os.fdopen(fd, "wt")
def _send_request(self, request: FileTimerRequest) -> None:
try:
file = self._open_non_blocking()
except Exception as e:
raise BrokenPipeError(
"Could not send the FileTimerRequest because FileTimerServer is not available."
) from e
with file:
json_request = request.to_json()
# Write request with no greater than select.PIPE_BUF is guarantee to be atomic.
if len(json_request) > select.PIPE_BUF:
raise RuntimeError(
f"FileTimerRequest larger than {select.PIPE_BUF} bytes "
f"is not supported: {json_request}"
)
file.write(json_request + "\n")
def acquire(self, scope_id: str, expiration_time: float) -> None:
self._send_request(
request=FileTimerRequest(
worker_pid=os.getpid(),
scope_id=scope_id,
expiration_time=expiration_time,
signal=self.signal,
),
)
def release(self, scope_id: str) -> None:
self._send_request(
request=FileTimerRequest(
worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0
),
)
[docs]class FileTimerServer:
"""
Server that works with ``FileTimerClient``. Clients are expected to be
running on the same host as the process that is running this server.
Each host in the job is expected to start its own timer server locally
and each server instance manages timers for local workers (running on
processes on the same host).
Args:
file_path: str, the path of a FIFO special file to be created.
max_interval: float, max interval in seconds for each watchdog loop.
daemon: bool, running the watchdog thread in daemon mode or not.
A daemon thread will not block a process to stop.
log_event: Callable[[Dict[str, str]], None], an optional callback for
logging the events in JSON format.
"""
def __init__(
self,
file_path: str,
run_id: str,
max_interval: float = 10,
daemon: bool = True,
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None,
) -> None:
self._file_path = file_path
self._run_id = run_id
self._max_interval = max_interval
self._daemon = daemon
self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None
self._is_client_started = False
if os.path.exists(self._file_path):
os.remove(self._file_path)
os.mkfifo(self._file_path)
# For test only. Count the number of requests received.
self._request_count = 0
# For test only. Process all requests and stop the server.
self._run_once = False
self._log_event = (
log_event if log_event is not None else lambda name, request: None
)
self._last_progress_time = int(time.time())
def start(self) -> None:
logger.info(
"Starting %s... max_interval=%s, daemon=%s, file_path=%s",
type(self).__name__,
self._max_interval,
self._daemon,
self._file_path,
)
self._watchdog_thread = threading.Thread(
target=self._watchdog_loop, daemon=self._daemon
)
logger.info("Starting watchdog thread...")
self._watchdog_thread.start()
self._log_event("watchdog started", None)
def stop(self) -> None:
logger.info("Stopping %s", type(self).__name__)
self._stop_signaled = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join(self._max_interval)
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
if os.path.exists(self._file_path):
os.remove(self._file_path)
self._log_event("watchdog stopped", None)
def run_once(self) -> None:
self._run_once = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join()
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
if os.path.exists(self._file_path):
os.remove(self._file_path)
@staticmethod
def is_process_running(pid: int):
"""
function to check process is running or not
"""
try:
# Check if the process exists and we can send signals to it
os.kill(pid, 0)
return True
except OSError:
return False
def _watchdog_loop(self) -> None:
# Open the pipe in blocking mode blocks the server thread.
# This is fine for the following reasons:
# 1. No client case usually does not happen.
# 2. We are running the watchdog loop in a separate daemon
# thread, which will not block the process to stop.
try:
fd = open(self._file_path)
except Exception as e:
logger.exception("Could not open the FileTimerServer pipe")
raise
with fd:
self._is_client_started = True
while not self._stop_signaled:
try:
run_once = self._run_once
self._run_watchdog(fd)
if run_once:
break
self._last_progress_time = int(time.time())
except Exception:
logger.exception("Error running watchdog")
def _run_watchdog(self, fd: io.TextIOWrapper) -> None:
timer_requests = self._get_requests(fd, self._max_interval)
self.register_timers(timer_requests)
now = time.time()
reaped_worker_pids = set()
kill_process = False
reap_signal = 0
all_expired_timers = self.get_expired_timers(now)
log_debug_info_for_expired_timers(
self._run_id,
{
pid: [expired_timer.to_json() for expired_timer in expired_timers]
for pid, expired_timers in all_expired_timers.items()
},
)
for worker_pid, expired_timers in all_expired_timers.items():
logger.info(
"Reaping worker_pid=[%s]. Expired timers: %s",
worker_pid,
self._get_scopes(expired_timers),
)
reaped_worker_pids.add(worker_pid)
# In case we have multiple expired timers, we find the first timer
# with a valid signal (>0) in the expiration time order.
expired_timers.sort(key=lambda timer: timer.expiration_time)
signal = 0
expired_timer = None
for timer in expired_timers:
self._log_event("timer expired", timer)
if timer.signal > 0:
signal = timer.signal
expired_timer = timer
break
if signal <= 0:
logger.info(
"No signal specified with worker=[%s]. Do not reap it.", worker_pid
)
continue
if self._reap_worker(worker_pid, signal):
logger.info(
"Successfully reaped worker=[%s] with signal=%s", worker_pid, signal
)
self._log_event("kill worker process", expired_timer)
kill_process = True
reap_signal = signal
else:
logger.error(
"Error reaping worker=[%s]. Will retry on next watchdog.",
worker_pid,
)
if kill_process and reap_signal > 0:
logger.info(
"Terminating the server process=[%s] because of expired timers",
os.getpid(),
)
self._reap_worker(os.getpid(), reap_signal)
self.clear_timers(reaped_worker_pids)
def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
return [r.scope_id for r in timer_requests]
def _get_requests(
self, fd: io.TextIOWrapper, max_interval: float
) -> List[FileTimerRequest]:
start = time.time()
requests = []
while not self._stop_signaled or self._run_once:
# For named pipe, readline() is blocking when at least one writer opens.
# It returns only when flush() is called at the writer side.
# Note that flush() is automatically called inside close().
# After the last writer closes, readline() is not blocking.
# It will return an empty string when it's at end-of-file.
# Since the client side always opens the pipe, writes a message and closes
# the pipe immediately, the readline() call below is not blocking for long.
json_request = fd.readline()
if len(json_request) == 0:
if self._run_once:
break
time.sleep(min(max_interval, 1))
else:
request = json.loads(json_request)
pid = request["pid"]
scope_id = request["scope_id"]
expiration_time = request["expiration_time"]
signal = request["signal"]
requests.append(
FileTimerRequest(
worker_pid=pid,
scope_id=scope_id,
expiration_time=expiration_time,
signal=signal,
)
)
now = time.time()
if now - start > max_interval:
break
return requests
def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_pid
scope_id = request.scope_id
expiration_time = request.expiration_time
self._request_count += 1
key = (pid, scope_id)
# negative expiration is a proxy for a release call
if expiration_time < 0:
if key in self._timers:
del self._timers[key]
else:
self._timers[key] = request
def clear_timers(self, worker_pids: Set[int]) -> None:
for pid, scope_id in list(self._timers.keys()):
if pid in worker_pids or not FileTimerServer.is_process_running(pid):
del self._timers[(pid, scope_id)]
def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
# pid -> [timer_requests...]
expired_timers: Dict[int, List[FileTimerRequest]] = {}
for request in self._timers.values():
if request.expiration_time <= deadline:
expired_scopes = expired_timers.setdefault(request.worker_pid, [])
expired_scopes.append(request)
return expired_timers
def _reap_worker(self, worker_pid: int, signal: int) -> bool:
try:
os.kill(worker_pid, signal)
return True
except ProcessLookupError:
logger.info("Process with pid=%s does not exist. Skipping", worker_pid)
return True
except Exception:
logger.exception("Error terminating pid=%s", worker_pid)
return False
def get_last_progress_time(self) -> int:
return self._last_progress_time if self._is_client_started else int(time.time())