<aside> 💡 This document explains how Elastic Horovod performs autoscaling.

</aside>

1. Checking the host discovery script

```python
# ElasticDriver._discover_hosts (excerpt), run by the discover thread
def _discover_hosts(self):
    first_update = True
    while not self._shutdown.is_set():
        self._wait_hosts_cond.acquire()
        try:
            # Ask the HostManager whether the discovered hosts changed
            update_res = self._host_manager.update_available_hosts()
            if update_res != HostUpdateResult.no_update:
                self._notify_workers_host_changes(self._host_manager.current_hosts, update_res)
                self._wait_hosts_cond.notify_all()
        except RuntimeError as e:
            ...
        finally:
            self._wait_hosts_cond.release()
        first_update = False
        # Interruptible sleep: returns early if shutdown is set
        self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)
```

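As explained earlier, the discover thread calls the HostManager's update_available_hosts() every DISCOVER_HOSTS_FREQUENCY_SECS seconds to check whether the hosts reported by the host discovery script have changed. When they have, it notifies the workers through _notify_workers_host_changes() and wakes every thread waiting on _wait_hosts_cond with notify_all().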
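
The loop combines a shutdown Event, used as an interruptible sleep, with a Condition that wakes waiting threads. The following is a minimal sketch of that same pattern, not Horovod code: HostPoller, check_fn and wait_for_change are hypothetical names, and the `with` statement stands in for the explicit acquire()/release() pair.

```python
import threading


class HostPoller:
    """Hypothetical sketch of a discover-thread loop (not a Horovod class)."""

    def __init__(self, check_fn, interval_secs):
        self._check_fn = check_fn            # returns True when the hosts changed
        self._interval = interval_secs
        self._shutdown = threading.Event()
        self._cond = threading.Condition()
        self.changes = 0
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        self._shutdown.set()                 # also wakes the loop out of wait()
        self._thread.join()

    def _loop(self):
        while not self._shutdown.is_set():
            with self._cond:                 # same as acquire() ... finally: release()
                if self._check_fn():
                    self.changes += 1
                    self._cond.notify_all()  # wake threads waiting for new hosts
            # Sleep for the interval, but return at once if stop() is called.
            self._shutdown.wait(self._interval)

    def wait_for_change(self, timeout):
        """Block until the loop reports a change; False on timeout."""
        with self._cond:
            return self._cond.wait(timeout)
```

Because the thread waits on the shutdown event rather than calling time.sleep(), stop() ends the loop immediately instead of after a full polling interval.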

```python
# HostManager.update_available_hosts (excerpt)
def update_available_hosts(self):
    def check_update(cur_host_slots, prev_host_slots):
        res = HostUpdateResult.no_update

        for prev_h in prev_host_slots:
            if prev_h not in cur_host_slots:
                # prev_h is a removed host
                res |= HostUpdateResult.removed

        for h in cur_host_slots:
            if h not in prev_host_slots:
                # h is an added host
                res |= HostUpdateResult.added
            elif cur_host_slots[h] > prev_host_slots[h]:
                # h has more slots added
                res |= HostUpdateResult.added
            elif cur_host_slots[h] < prev_host_slots[h]:
                # h has removed some slots
                res |= HostUpdateResult.removed
        return res

    prev_host_slots = self._current_hosts.host_slots
    prev_host_assignment_order = self._current_hosts.host_assignment_order
    # Run the discovery script and get the latest {host: slots} mapping
    host_slots = self._discovery.find_available_hosts_and_slots()
    if prev_host_slots != host_slots:
        # Ignore blacklisted hosts when computing the new assignment order
        available_hosts = set([host for host in host_slots.keys() if not self._hosts_state[host].is_blacklisted()])
        host_assignment_order = HostManager.order_available_hosts(available_hosts, prev_host_assignment_order)
        self._current_hosts = DiscoveredHosts(host_slots=host_slots,
                                              host_assignment_order=host_assignment_order)
        return check_update(self._current_hosts.host_slots, prev_host_slots)
    else:
        return HostUpdateResult.no_update
```

Here, self._discovery is the HostDiscoveryScript object created from the argument passed with the --host-discovery-script option when Elastic Horovod was launched.

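update_available_hosts() calls find_available_hosts_and_slots() on this HostDiscoveryScript object to obtain the latest host-to-slots mapping and compares it with the previous one. If the mapping has changed, it filters out blacklisted hosts, recomputes the host assignment order, and stores the result as the new current hosts.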
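
To make the "host-to-slots mapping" concrete, here is a hypothetical sketch of turning a discovery script's stdout into a dict. It assumes one `hostname` or `hostname:slots` entry per line and a caller-chosen default slot count; parse_host_slots, run_discovery_script and default_slots are illustrative names, not Horovod's implementation.

```python
import subprocess


def parse_host_slots(output, default_slots=1):
    """Hypothetical parser: one 'hostname' or 'hostname:slots' entry per line."""
    host_slots = {}
    for line in output.splitlines():
        line = line.strip()
        if not line:
            continue                       # skip blank lines
        if ':' in line:
            host, slots = line.rsplit(':', 1)
            host_slots[host] = int(slots)
        else:
            host_slots[line] = default_slots
    return host_slots


def run_discovery_script(cmd, default_slots=1):
    """Run a discovery command and parse its stdout (no timeout or retry policy)."""
    out = subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
    return parse_host_slots(out, default_slots)
```

With the result held as a plain dict, the `prev_host_slots != host_slots` test in update_available_hosts() is ordinary dict equality: any added host, removed host, or changed slot count makes the two mappings differ.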

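check_update() then compares the two mappings and returns a HostUpdateResult flag describing the change: added if a host appeared or gained slots, removed if a host disappeared or lost slots. Because the bits are OR-ed together, a single update can be both added and removed.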
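
The flag logic can be reproduced with Python's enum.IntFlag. In this sketch, UpdateResult is a hypothetical stand-in for HostUpdateResult with bit values chosen for illustration.

```python
from enum import IntFlag


class UpdateResult(IntFlag):
    """Hypothetical stand-in for HostUpdateResult; values are illustrative."""
    no_update = 0
    removed = 1
    added = 2


def check_update(cur, prev):
    res = UpdateResult.no_update
    for h in prev:
        if h not in cur:
            res |= UpdateResult.removed   # whole host disappeared
    for h, slots in cur.items():
        if h not in prev or slots > prev[h]:
            res |= UpdateResult.added     # new host, or more slots on a known host
        elif slots < prev[h]:
            res |= UpdateResult.removed   # known host lost some slots
    return res
```

For example, going from `{'a': 2, 'b': 1}` to `{'a': 4, 'c': 1}` sets both bits (b removed, a grew, c added), so a caller only needs to test `!= no_update` to know something changed.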

If the flag is anything other than no_update, _discover_hosts() calls _notify_workers_host_changes().

2. Sending the host update and timestamp

```python
# ElasticDriver._notify_workers_host_changes (excerpt)
def _notify_workers_host_changes(self, current_hosts, update_res):
    next_host_assignments = {}
    if current_hosts.count_available_slots() >= self._min_np:
        # Assignments are required to be stable via contract
        next_host_assignments, _ = self._get_host_assignments(current_hosts)

    if next_host_assignments == self.host_assignments:
        # Skip notifying workers when host changes would not result in changes of host assignments
        logging.debug('no host assignment changes, skipping notifications')
        return

    # Only the coordinator is notified directly
    coordinator_slot_info = self.get_coordinator_info()
    if not coordinator_slot_info:
        logging.debug('no coordinator info, skipping notifications')
        return

    coordinator_client = self.get_worker_client(coordinator_slot_info)
    if not coordinator_client:
        logging.debug('no coordinator client, skipping notifications')
        return

    timestamp = _epoch_time_s()
    try:
        coordinator_client.notify_hosts_updated(timestamp, update_res)
    except:
        # Failures are swallowed; they are only logged at verbose >= 2
        if self._verbose >= 2:
            logging.exception('failed to notify {}[{}] of host updates'
                              .format(coordinator_slot_info.hostname,
                                      coordinator_slot_info.local_rank))
```

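_notify_workers_host_changes() first computes the next host assignments (only when the available slots reach min_np; otherwise the next assignments stay empty) and returns early if they are identical to the current assignments, since the host change would not affect any worker. Otherwise it calls get_coordinator_info() to identify the coordinator from the worker information held by the ElasticDriver, calls get_worker_client() to obtain the coordinator's WorkerNotificationClient, and then calls that client's notify_hosts_updated() with the current timestamp and the update flag. If no coordinator or client is available, the notification is skipped.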
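
The guard chain can be isolated into a pure function for testing. This is a hypothetical sketch, not Horovod code: compute_assignments stands in for _get_host_assignments(), a `send` of None stands in for "no coordinator info or client", and the timestamp is assumed to be whole epoch seconds like _epoch_time_s().

```python
import time
from typing import Callable, Optional


def notify_if_needed(available_slots: int, min_np: int,
                     compute_assignments: Callable[[], dict],
                     current_assignments: dict,
                     send: Optional[Callable[[int, int], None]],
                     update_res: int) -> bool:
    """Hypothetical sketch of the guard chain; returns True if a notification was sent."""
    next_assignments = {}
    if available_slots >= min_np:
        next_assignments = compute_assignments()
    if next_assignments == current_assignments:
        return False                    # host change would not move any rank
    if send is None:
        return False                    # no coordinator client to talk to
    timestamp = int(time.time())        # assumed: whole epoch seconds
    try:
        send(timestamp, update_res)
    except Exception:
        return False                    # the original only logs this at verbose >= 2
    return True
```

Note the edge case this exposes: if the slots drop below min_np, the next assignments become `{}`, which differs from any non-empty current assignment, so the coordinator is still notified.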

```python
# WorkerNotificationClient.notify_hosts_updated
def notify_hosts_updated(self, timestamp, update_res):
    self._send(HostsUpdatedRequest(timestamp, update_res))

# Request helpers inherited from the base network client
def _send(self, req, stream=None):
    """
    Sends the request and returns the response object.
    Streaming data response is transferred to the optional stream parameter.
    """
    # Since all the addresses were vetted, use the first one.
    addr = list(self._addresses.values())[0][0]
    return self._send_one(addr, req, stream)

def _send_one(self, addr, req, stream=None):
    """
    Send the request to the server and retry on errors.
    Streams data that follow an AckStreamResponse to the given utf8 text stream.
    """
    for iter in range(self._attempts):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(addr)
            rfile = sock.makefile('rb')
            wfile = sock.makefile('wb')
            try:
                self._wire.write(req, wfile)
                resp = self._wire.read(rfile)
                if stream and isinstance(resp, AckStreamResponse):
                    # stream and byte content in rfile are expected to be utf8 text
                    from encodings.utf_8 import StreamReader
                    r = StreamReader(rfile)
                    shutil.copyfileobj(r, stream)
                return resp
            finally:
                rfile.close()
                wfile.close()
        except:
            if iter == self._attempts - 1:
                # Raise exception on the last retry.
                raise
        finally:
            sock.close()
```
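
notify_hosts_updated() wraps the timestamp and update flag in a HostsUpdatedRequest and hands it to _send(), which picks the first vetted address and calls _send_one(). _send_one() opens a fresh TCP connection on every attempt, writes the request, reads one response, and re-raises only after the last attempt fails. The sketch below reproduces that retry pattern under stated assumptions: send_with_retries, the newline-delimited JSON wire format, and the backoff and timeout values are all hypothetical; Horovod uses its own Wire class for serialization.

```python
import json
import socket
import time


def send_with_retries(addr, request, attempts=3, backoff_secs=0.1, timeout_secs=5.0):
    """Hypothetical sketch of _send_one(): a new socket per attempt, re-raise on the last failure.

    Assumed wire format: one JSON object per line in each direction.
    """
    for attempt in range(attempts):
        try:
            with socket.create_connection(addr, timeout=timeout_secs) as sock:
                with sock.makefile('rwb') as f:
                    f.write(json.dumps(request).encode('utf-8') + b'\n')
                    f.flush()
                    return json.loads(f.readline())
        except OSError:
            if attempt == attempts - 1:
                raise                   # out of attempts: surface the last error
            time.sleep(backoff_secs)    # brief pause before reconnecting
```

Unlike the bare `except:` in the original, this sketch retries only on OSError (connection refused, timeouts, resets), so programming errors such as a malformed response surface immediately instead of being retried.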