<aside> 💡 이 문서는 Elastic Horovod의 Autoscaling 과정을 설명합니다.
</aside>
def _discover_hosts(self):
    """Poll the host manager for host changes until shutdown is requested.

    On every pass the loop holds ``self._wait_hosts_cond`` while it asks
    ``self._host_manager`` to refresh its host list. If the result is anything
    other than ``HostUpdateResult.no_update``, workers are notified through
    ``_notify_workers_host_changes`` and every thread waiting on the condition
    is woken up. The loop then waits on ``self._shutdown`` with a timeout of
    ``DISCOVER_HOSTS_FREQUENCY_SECS`` before polling again.
    """
    first_update = True
    while not self._shutdown.is_set():
        # The condition is held for the whole refresh-and-notify step and is
        # released when the ``with`` block exits, even on an exception.
        with self._wait_hosts_cond:
            try:
                host_change = self._host_manager.update_available_hosts()
                hosts_changed = host_change != HostUpdateResult.no_update
                if hosts_changed:
                    self._notify_workers_host_changes(
                        self._host_manager.current_hosts, host_change)
                    self._wait_hosts_cond.notify_all()
            except RuntimeError as e:
                # NOTE(review): the handler body is elided in this excerpt;
                # presumably it is where ``first_update`` is consulted.
                ...
        first_update = False
        self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)
Discover thread는 앞서 설명드렸듯 DISCOVER_HOSTS_FREQUENCY_SECS 마다 HostManager의 update_available_hosts() 함수를 호출하여 호스트 스크립트 파일의 호스트 정보가 변경되었는지 확인합니다.
def update_available_hosts(self):
    """Refresh the discovered hosts and report how they changed.

    Calls ``self._discovery.find_available_hosts_and_slots()`` and compares the
    result with the slots stored in ``self._current_hosts``. When they differ,
    ``self._current_hosts`` is replaced by a new ``DiscoveredHosts`` whose
    assignment order is computed from the hosts that are not blacklisted.

    Returns:
        ``HostUpdateResult.no_update`` when the slot mapping is unchanged;
        otherwise the value produced by ``check_update``, i.e.
        ``HostUpdateResult.added`` and/or ``HostUpdateResult.removed`` OR-ed
        together.
    """
    def check_update(cur_host_slots, prev_host_slots):
        """Diff two host -> slot-count mappings into a HostUpdateResult."""
        change = HostUpdateResult.no_update

        # A host present before but missing now is a removal.
        if any(old_host not in cur_host_slots for old_host in prev_host_slots):
            change |= HostUpdateResult.removed

        for host in cur_host_slots:
            if host not in prev_host_slots or cur_host_slots[host] > prev_host_slots[host]:
                # Either a brand-new host or a known host that gained slots.
                change |= HostUpdateResult.added
            elif cur_host_slots[host] < prev_host_slots[host]:
                # A known host that lost some of its slots.
                change |= HostUpdateResult.removed
        return change

    previous = self._current_hosts
    old_slots = previous.host_slots
    old_order = previous.host_assignment_order
    new_slots = self._discovery.find_available_hosts_and_slots()

    if old_slots != new_slots:
        # Blacklisted hosts are left out of the assignment order only; the
        # slot mapping passed to DiscoveredHosts is the one just discovered.
        usable_hosts = {
            host for host in new_slots
            if not self._hosts_state[host].is_blacklisted()
        }
        new_order = HostManager.order_available_hosts(usable_hosts, old_order)
        self._current_hosts = DiscoveredHosts(host_slots=new_slots,
                                              host_assignment_order=new_order)
        return check_update(self._current_hosts.host_slots, old_slots)

    return HostUpdateResult.no_update
여기서 self._discovery는 처음 Elastic Horovod 실행 시 --host-discovery-script 옵션을 통해 전달받은 인자를 기반으로 생성한 HostDiscoveryScript 객체입니다.
update_available_hosts() 함수에서는 HostDiscoveryScript 객체의 find_available_hosts_and_slots() 함수를 수행하여 새 호스트 정보를 가져오고 기존의 호스트 정보와 비교합니다.
비교는 check_update() 함수를 통해 이루어지며, 비교한 뒤 호스트 정보 변경 정보를 담는 flag를 반환합니다.
flag 값이 no_update가 아니면 _notify_workers_host_changes() 함수를 호출합니다.
def _notify_workers_host_changes(self, current_hosts, update_res):
next_host_assignments = {}
if current_hosts.count_available_slots() >= self._min_np:
# Assignments are required to be stable via contract
next_host_assignments, _ = self._get_host_assignments(current_hosts)
if next_host_assignments == self.host_assignments:
# Skip notifying workers when host changes would not result in changes of host assignments
logging.debug('no host assignment changes, skipping notifications')
return
coordinator_slot_info = self.get_coordinator_info()
if not coordinator_slot_info:
logging.debug('no coordinator info, skipping notifications')
return
coordinator_client = self.get_worker_client(coordinator_slot_info)
if not coordinator_client:
logging.debug('no coordinator client, skipping notifications')
return
timestamp = _epoch_time_s()
try:
coordinator_client.notify_hosts_updated(timestamp, update_res)
except:
if self._verbose >= 2:
logging.exception('failed to notify {}[{}] of host updates'
.format(coordinator_slot_info.hostname,
coordinator_slot_info.local_rank))
_notify_workers_host_changes() 함수에서는 get_coordinator_info() 함수를 호출하여 ElasticDriver가 가지고 있는 Worker 정보를 이용해 코디네이터가 누구인지 확인하고, get_worker_client() 함수를 호출하여 코디네이터의 WorkerNotificationClient를 가져옵니다. 그리고, 해당 WorkerNotificationClient의 notify_hosts_updated() 함수를 호출합니다.
def notify_hosts_updated(self, timestamp, update_res):
    """Send a ``HostsUpdatedRequest`` built from the given arguments.

    Args:
        timestamp: value passed through as the request's first argument.
        update_res: value passed through as the request's second argument.

    Returns:
        None; the response returned by ``self._send`` is discarded.
    """
    request = HostsUpdatedRequest(timestamp, update_res)
    self._send(request)
def _send(self, req, stream=None): """
Sends the request and returns the response object.
Streaming data response is transferred to the optional stream parameter.
"""
# Since all the addresses were vetted, use the first one.
addr = list(self._addresses.values())[0][0]
return self._send_one(addr, req, stream)
def _send_one(self, addr, req, stream=None):
    """
    Send the request to the server and retry on errors.
    Streams data that follow a AckStreamResponse to the given utf8 text stream.

    Args:
        addr: address passed to ``socket.connect`` (an AF_INET socket is used).
        req: request object written with ``self._wire.write``.
        stream: optional text stream; when given and the response is an
            ``AckStreamResponse``, the remaining bytes are decoded as utf8
            and copied into it.

    Returns:
        The response read with ``self._wire.read``, or ``None`` when
        ``self._attempts`` is 0 and no attempt is made.

    Raises:
        Exception: whatever the final attempt raised; errors from earlier
            attempts are swallowed and the request is retried.
    """
    for attempt in range(self._attempts):
        # A fresh socket per attempt; the ``with`` closes it on every path.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.connect(addr)
                # Both file objects are closed on exit, even if closing one
                # of them raises.
                with sock.makefile('rb') as rfile, sock.makefile('wb') as wfile:
                    self._wire.write(req, wfile)
                    resp = self._wire.read(rfile)
                    if stream and isinstance(resp, AckStreamResponse):
                        # stream and byte content in rfile are expected to be utf8 text
                        from encodings.utf_8 import StreamReader
                        r = StreamReader(rfile)
                        shutil.copyfileobj(r, stream)
                    return resp
            except Exception:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit propagate immediately.
                if attempt == self._attempts - 1:
                    # Raise exception on the last retry.
                    raise