Source code for torch.distributed.elastic.rendezvous.etcd_store
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import random
import time
from base64 import b64decode, b64encode
from typing import Optional

import etcd  # type: ignore[import]

# pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
from torch.distributed import Store


def cas_delay():
    """
    Pause for a short random interval (between 0 and 0.1 seconds).

    Used between compare-and-swap retries. It does not affect correctness;
    randomizing the back-off only reduces the number of requests sent to the
    etcd server.
    """
    pause_seconds = random.uniform(0, 0.1)
    time.sleep(pause_seconds)


# pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
    """
    Implement a c10 Store interface by piggybacking on the rendezvous etcd instance.

    This is the store object returned by ``EtcdRendezvous``. Keys and values are
    base64-encoded before being written, so arbitrary binary data can be stored.
    Every key is placed under the ``etcd_store_prefix`` directory.
    """

    def __init__(
        self,
        etcd_client,
        etcd_store_prefix,
        # Default timeout same as in c10d/Store.hpp
        timeout: Optional[datetime.timedelta] = None,
    ):
        super().__init__()  # required for pybind trampoline.

        self.client = etcd_client
        self.prefix = etcd_store_prefix

        # When no timeout is given, set_timeout is not called and the base
        # Store's own timeout stays in effect.
        if timeout is not None:
            self.set_timeout(timeout)

        # Normalize the prefix so "<prefix>/<b64 key>" can be built by concatenation.
        if not self.prefix.endswith("/"):
            self.prefix += "/"

    def set(self, key, value):
        """
        Write a key/value pair into ``EtcdStore``.

        Both key and value may be a Python ``str``, ``bytes`` or ``bytearray``.
        """
        self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))

    def get(self, key) -> bytes:
        """
        Get a value by key, possibly doing a blocking wait.

        If key is not immediately present, will do a blocking wait
        for at most ``timeout`` duration or until the key is published.

        Returns:
            value ``(bytes)``

        Raises:
            LookupError - If key still not published after timeout
        """
        b64_key = self.prefix + self._encode(key)
        kvs = self._try_wait_get([b64_key])

        if kvs is None:
            raise LookupError(f"Key {key} not found in EtcdStore")

        return self._decode(kvs[b64_key])

    def add(self, key, num: int) -> int:
        """
        Atomically increment a value by an integer amount.

        The integer is represented as a string using base 10. If key is not present,
        a default value of ``0`` will be assumed.

        Returns:
             the new (incremented) value
        """
        b64_key = self._encode(key)
        # c10d Store assumes value is an integer represented as a decimal string
        try:
            # Assume default value "0" if this key doesn't exist yet:
            # create-only write (prevExist=False) storing 0 + num.
            node = self.client.write(
                key=self.prefix + b64_key,
                value=self._encode(str(num)),  # i.e. 0 + num
                prevExist=False,
            )
            return int(self._decode(node.value))
        except etcd.EtcdAlreadyExist:
            pass

        # The key already exists: read-modify-write with compare-and-swap,
        # retrying (after a short random delay) when another writer wins.
        while True:
            # Note: c10d Store does not have a method to delete keys, so we
            # can be sure it's still there.
            node = self.client.get(key=self.prefix + b64_key)
            new_value = self._encode(str(int(self._decode(node.value)) + num))
            try:
                node = self.client.test_and_set(
                    key=node.key, value=new_value, prev_value=node.value
                )
                return int(self._decode(node.value))
            except etcd.EtcdCompareFailed:
                cas_delay()

    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
        """
        Wait until all of the keys are published, or until timeout.

        ``override_timeout``, when given, replaces the store timeout for this call.
        Repeated keys are allowed and are waited for only once.

        Raises:
            LookupError - if timeout occurs
        """
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(b64_keys, override_timeout)
        if kvs is None:
            raise LookupError("Timeout while waiting for keys in EtcdStore")
        # No return value on success

    def check(self, keys) -> bool:
        """Check if all of the keys are immediately present (without waiting)."""
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(
            b64_keys,
            override_timeout=datetime.timedelta(microseconds=1),  # as if no wait
        )
        return kvs is not None

    def _encode(self, value) -> str:
        """
        Base64-encode ``value`` so arbitrary binary data can be stored in etcd.

        ``str`` input is UTF-8 encoded first. ``bytes`` and ``bytearray`` are
        encoded as-is. Returns the base64 text as a ``str``.

        Raises:
            ValueError - if ``value`` is neither ``str`` nor bytes/bytearray
        """
        if isinstance(value, (bytes, bytearray)):
            return b64encode(value).decode()
        if isinstance(value, str):
            return b64encode(value.encode()).decode()
        raise ValueError("Value must be of type str or bytes")

    def _decode(self, value) -> bytes:
        """
        Decode a base64 string given as ``str``, ``bytes`` or ``bytearray``.

        The return type is ``bytes``, which is more convenient with the Store
        interface.

        Raises:
            ValueError - if ``value`` is neither ``str`` nor bytes/bytearray
        """
        if isinstance(value, (bytes, bytearray)):
            return b64decode(value)
        if isinstance(value, str):
            return b64decode(value.encode())
        raise ValueError("Value must be of type str or bytes")

    def _try_wait_get(self, b64_keys, override_timeout=None):
        """
        Get all of the (base64-encoded, prefixed) etcd keys at once.

        Wait until all the keys are published or the timeout occurs. This is a
        helper for the public interface methods.

        Returns:
            a dict {etcd key -> etcd value} on success, or None on timeout.
        """
        # Deduplicate up front. The result is keyed by etcd key, so repeated
        # keys would make the size comparison below unsatisfiable and turn a
        # successful wait into a spurious timeout. The set also gives O(1)
        # membership tests.
        wanted = set(b64_keys)
        if not wanted:
            # Nothing to wait for.
            return {}

        timeout = self.timeout if override_timeout is None else override_timeout  # type: ignore[attr-defined]
        deadline = time.time() + timeout.total_seconds()

        while True:
            # Read whole directory (of keys), filter only the ones waited for
            all_nodes = None
            try:
                all_nodes = self.client.get(key=self.prefix)
                req_nodes = {
                    node.key: node.value
                    for node in all_nodes.children
                    if node.key in wanted
                }

                if len(req_nodes) == len(wanted):
                    # All keys are available
                    return req_nodes
            except etcd.EtcdKeyNotFound:
                pass

            watch_timeout = deadline - time.time()
            if watch_timeout <= 0:
                return None

            try:
                # Watch for changes after the snapshot just read. The index is
                # 0 when the directory read above did not succeed.
                index = all_nodes.etcd_index + 1 if all_nodes else 0
                self.client.watch(
                    key=self.prefix,
                    recursive=True,
                    timeout=watch_timeout,
                    index=index,
                )
            except etcd.EtcdWatchTimedOut:
                if time.time() >= deadline:
                    return None
                else:
                    continue
            except etcd.EtcdEventIndexCleared:
                continue
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, we apply Facebook’s Cookies Policy. Learn more, including about available controls: Cookies Policy.