Source code for ravnest.node

from concurrent import futures
import asyncio
import grpc
import threading
import multiprocessing
import threading
from threading import Thread
import numpy as np
import psutil
import pickle
import shutil
import time
from .communication import Communication
from .compute import Compute
from .utils import *
from .strings import *
from .endpoints import GrpcService

from .protos.server_pb2_grpc import add_CommServerServicer_to_server

# Every process, lock and manager created by Node comes from this single
# 'spawn' multiprocessing context.
mp = multiprocessing.get_context(method='spawn')

class Node():
    """
    Responsible for managing the computational and communication aspects of a
    distributed machine learning model, including model initialization,
    parameter synchronization, forward and backward passes, loss computation,
    and communication between different nodes in the system.

    :param name: The name of the node. Strictly in the format: 'node_0', 'node_17' etc.
    :type name: str
    :param model: The PyTorch model associated with the node. When ``None`` the
        model is loaded with ``torch.jit.load`` from ``<template_path>submod.pt``.
    :type model: torch.nn.Module
    :param optimizer: Optimizer class. It is called with the model parameters
        (a clone of them on root/stem nodes) and ``optimizer_params``.
    :type optimizer: type
    :param optimizer_params: Keyword arguments for the optimizer; ``None`` means none.
    :type optimizer_params: dict, optional
    :param lr_scheduler: Learning rate scheduler class, called with the optimizer
        and ``lr_scheduler_params``.
    :type lr_scheduler: type, optional
    :param lr_scheduler_params: Keyword arguments for the scheduler; ``None`` means none.
    :type lr_scheduler_params: dict, optional
    :param lr_step_on_epoch_change: Whether to step the learning rate scheduler on
        epoch change (otherwise it is stepped after every parameter update).
    :type lr_step_on_epoch_change: bool
    :param criterion: The loss function (only kept on the leaf node).
    :type criterion: callable
    :param update_frequency: Number of backward passes per parameter update.
    :type update_frequency: int
    :param reduce_factor: Frequency at which all-reduce will be triggered i.e. trigger
        all-reduce every time these many updates are done. Defaults to ``len(labels)``.
    :type reduce_factor: int
    :param labels: Dataloader containing labels.
    :type labels: torch.utils.data.DataLoader
    :param test_labels: Test labels for validation.
    :type test_labels: torch.utils.data.DataLoader
    :param device: The device on which the model will be run (CPU or GPU).
    :type device: torch.device
    :param loss_filename: The filename to save loss values.
    :type loss_filename: str
    :param compression: Whether to use compression.
    :type compression: bool
    :param average_optim: Forwarded unchanged to ``Communication`` as ``average_optim``
        (presumably whether optimizer state is averaged too - TODO confirm).
    :type average_optim: bool
    :param kwargs: Additional arguments; updated in place with the node's JSON config.
    :type kwargs: dict
    :raises ValueError: If the config has no ``template_path``, the node type is
        unknown, or neither ``reduce_factor`` nor ``labels`` is given.
    """

    def __init__(self, name=None, model=None, optimizer=None, optimizer_params=None,
                 lr_scheduler=None, lr_scheduler_params=None, lr_step_on_epoch_change=True,
                 criterion=None, update_frequency=1, reduce_factor=None, labels=None,
                 test_labels=None, device=torch.device('cpu'), loss_filename='losses.txt',
                 compression=False, average_optim=False, **kwargs):
        # ``None`` replaces the former shared mutable ``{}`` defaults.
        if optimizer_params is None:
            optimizer_params = {}
        if lr_scheduler_params is None:
            lr_scheduler_params = {}

        self.manager = mp.Manager()
        self.forward_lock = mp.Lock()
        self.backward_lock = mp.Lock()
        self.latest_weights_lock = mp.Lock()
        self.reduce_lock = mp.Lock()
        self.gather_lock = mp.Lock()

        node_metadata = load_node_json_configs(node_name=name)
        kwargs.update(node_metadata)
        self.node_type = kwargs.get('node_type', None)
        template_path = kwargs.get('template_path', None)
        if template_path is None:
            raise ValueError('Node config for {!r} has no template_path'.format(name))
        # The raw path is joined directly with file names below (e.g. 'submod.pt'),
        # so it presumably ends with a separator; the stored copy drops that last
        # character and is joined with '/' instead.
        self.template_path = template_path[:-1]
        self.local_address = '{}:{}'.format(kwargs.get('local_host', None), kwargs.get('local_port', None))
        self.name = name
        self.loss_filename = loss_filename
        self.reset()

        if model is None:
            self.model = torch.jit.load(template_path + 'submod.pt')
        else:
            self.model = model
        self.device = device
        self.compression = compression
        if not next(self.model.parameters()).is_cuda:
            self.model.to(device)

        # Manager-backed containers; these are the objects handed to the gRPC
        # server process through __getstate__/__setstate__.
        self.load_forward_buffer = self.manager.list()
        self.load_backward_buffer = self.manager.list()
        self.reduce_ring_buffers = self.manager.dict()
        self.gather_ring_buffers = self.manager.dict()
        self.latest_weights_buffer = self.manager.dict()
        self.reduce_iteration = self.manager.dict()
        self.gather_iteration = self.manager.dict()
        self.start_server_flag = self.manager.Value(bool, False)

        if kwargs.get('ring_ids', None) is not None:
            self.ring_ids = kwargs.get('ring_ids', None)
            for ring_id, _ in self.ring_ids.items():
                self.reduce_iteration[ring_id] = 0
                self.gather_iteration[ring_id] = 0
            print('ring ids: ', self.ring_ids)
        self.rank = kwargs.get('rank', None)
        print('\n Rank: ', self.rank)
        self.ring_size = kwargs.get('ring_size', None)

        # Each ring owns the contiguous run of trainable parameter names starting
        # at its first parameter and ending before the next ring's first one.
        # NOTE(review): the next ring is looked up as ``ring_id + 1``, which
        # assumes consecutive integer ring ids.
        data_dict_keys = get_trainable_param_names(model=self.model)
        self.ring_param_keys = {}
        for i, (ring_id, first_param) in enumerate(self.ring_ids.items()):
            start_index = data_dict_keys.index(first_param)
            if i < len(self.ring_ids) - 1:
                keys = data_dict_keys[start_index:data_dict_keys.index(self.ring_ids[ring_id + 1])]
            else:
                keys = data_dict_keys[start_index:]
            self.ring_param_keys[ring_id] = keys

        # Same partitioning per remote address, recording the first/last name of
        # each range plus a reverse name -> address mapping.
        self.param_address_mapping = {}
        param_addresses = kwargs.get('param_addresses', None)
        self.retrieve_latest_params_data = {}
        print(param_addresses)
        address_order = list(param_addresses.keys())
        for i, (address, first_param) in enumerate(param_addresses.items()):
            start_index = data_dict_keys.index(first_param)
            if i < len(param_addresses) - 1:
                next_first = param_addresses[address_order[i + 1]]
                keys = data_dict_keys[start_index:data_dict_keys.index(next_first)]
            else:
                keys = data_dict_keys[start_index:]
            self.retrieve_latest_params_data[address] = (keys[0], keys[-1])
            for param_name in keys:
                self.param_address_mapping[param_name] = address
        print('Ring param keys: ', self.ring_param_keys.keys())

        # Label/criterion state is only used on the leaf node. Defaulting to None
        # guarantees the attributes exist even when a leaf gets no labels.
        self.criterion = None
        self.test_labels = None
        self.test_labels_iterator = None
        self.labels = None
        self.labels_iterator = None
        if self.node_type == NodeTypes.LEAF:
            self.criterion = criterion
            if test_labels is not None:
                self.test_labels = test_labels
            if labels is not None:
                self.labels = labels
                if isinstance(labels, torch.Tensor):
                    self.labels_iterator = labels
                else:
                    self.labels_iterator = iter(labels)
        self.net_val_accuracy = []

        self.forward_target_host = kwargs.get('forward_target_host', None)
        self.forward_target_port = kwargs.get('forward_target_port', None)
        self.backward_target_host = kwargs.get('backward_target_host', None)
        self.backward_target_port = kwargs.get('backward_target_port', None)

        self.output_tensors = {}
        self.input_tensors = {}
        self.n_backwards = 0
        self.n_forwards = 0
        self.forward_pass_id = 0
        self.latest_backward_id = 0
        self.root_compute = False
        self.update_frequency = update_frequency
        if not reduce_factor:
            if labels is None:
                raise ValueError('reduce_factor must be given when labels is None')
            reduce_factor = len(labels)
        self.reduce_threshold = self.update_frequency * reduce_factor

        self.submod_file = kwargs.get('submod_file', None)
        self.node_status = NodeStatus.IDLE
        self.tensor_id = '0_{}'.format(self.submod_file)
        self.averaged_params_buffer = {}
        self.average_no = 0
        self.average_optim = average_optim
        self.cluster_length = kwargs['cluster_length']
        self.lr_step_on_epoch_change = lr_step_on_epoch_change

        # Templates are locally generated pickles; pickle.load must not be
        # pointed at untrusted files.
        self.input_template = None
        self.output_template = None
        if self.submod_file is not None:
            with open('{}{}_input.pkl'.format(template_path, self.submod_file), 'rb') as fin:
                self.input_template = pickle.load(fin)
            with open('{}{}_output.pkl'.format(template_path, self.submod_file), 'rb') as fin:
                self.output_template = pickle.load(fin)

        self.model_inputs_template = None
        if self.node_type == NodeTypes.ROOT:
            with open('{}model_inputs.pkl'.format(template_path), 'rb') as fin:
                self.model_inputs_template = pickle.load(fin)

        # Root and stem nodes optimise a clone of the parameters; the leaf node
        # optimises the live model parameters.
        if self.node_type == NodeTypes.LEAF:
            self.optimizer = optimizer(self.model.parameters(), **optimizer_params)
        elif self.node_type in (NodeTypes.ROOT, NodeTypes.STEM):
            self.optimizer = optimizer(current_model_params_clone(self.model), **optimizer_params)
        else:
            raise ValueError('Unknown node type: {!r}'.format(self.node_type))

        self.lr_scheduler = None
        if lr_scheduler is not None:
            self.lr_scheduler = lr_scheduler(self.optimizer, **lr_scheduler_params)

        self.compute_session = Compute(model=self.model, optimizer=self.optimizer,
                                       criterion=self.criterion, compression=self.compression,
                                       input_tensors=self.input_tensors,
                                       latest_weights_buffer=self.latest_weights_buffer,
                                       latest_weights_lock=self.latest_weights_lock,
                                       tensor_id=self.tensor_id,
                                       output_template=self.output_template,
                                       input_template=self.input_template,
                                       node_type=self.node_type, submod_file=self.submod_file,
                                       loss_filename=self.loss_filename, device=self.device)

        self.comm_session = Communication(name=self.name, model=self.model, optimizer=self.optimizer,
                                          node_type=self.node_type, rank=self.rank,
                                          ring_size=self.ring_size,
                                          ring_param_keys=self.ring_param_keys,
                                          ring_ids=self.ring_ids,
                                          param_address_mapping=self.param_address_mapping,
                                          reduce_lock=self.reduce_lock, gather_lock=self.gather_lock,
                                          device=self.device, compression=self.compression,
                                          forward_target_host=self.forward_target_host,
                                          forward_target_port=self.forward_target_port,
                                          backward_target_host=self.backward_target_host,
                                          backward_target_port=self.backward_target_port,
                                          retrieve_latest_params_data=self.retrieve_latest_params_data,
                                          output_tensors=self.output_tensors,
                                          input_tensors=self.input_tensors,
                                          reduce_ring_buffers=self.reduce_ring_buffers,
                                          gather_ring_buffers=self.gather_ring_buffers,
                                          reduce_iteration=self.reduce_iteration,
                                          gather_iteration=self.gather_iteration,
                                          submod_file=self.submod_file, tensor_id=self.tensor_id,
                                          averaged_params_buffer=self.averaged_params_buffer,
                                          average_no=self.average_no,
                                          average_optim=self.average_optim,
                                          output_template=self.output_template,
                                          model_inputs_template=self.model_inputs_template)
        self.start()

    def init_server(self, load_forward_buffer=None, load_backward_buffer=None, reduce_ring_buffers=None,
                    gather_ring_buffers=None, latest_weights_buffer=None, forward_lock=None,
                    backward_lock=None, reduce_lock=None, gather_lock=None, latest_weights_lock=None,
                    reduce_iteration=None, gather_iteration=None, max_workers=4):
        """Initialize the gRPC server for handling communication with other nodes.

        :param load_forward_buffer: Shared buffer for incoming forward pass data, defaults to None
        :type load_forward_buffer: multiprocessing.Manager.list, optional
        :param load_backward_buffer: Shared buffer for incoming backward pass data, defaults to None
        :type load_backward_buffer: multiprocessing.Manager.list, optional
        :param reduce_ring_buffers: Shared dictionary for reduce operation buffers, defaults to None
        :type reduce_ring_buffers: multiprocessing.Manager.dict, optional
        :param gather_ring_buffers: Shared dictionary for gather operation buffers, defaults to None
        :type gather_ring_buffers: multiprocessing.Manager.dict, optional
        :param latest_weights_buffer: Shared dictionary for the latest weights, defaults to None
        :type latest_weights_buffer: multiprocessing.Manager.dict, optional
        :param forward_lock: Lock for synchronizing access to forward buffers, defaults to None
        :type forward_lock: multiprocessing.Lock, optional
        :param backward_lock: Lock for synchronizing access to backward buffers, defaults to None
        :type backward_lock: multiprocessing.Lock, optional
        :param reduce_lock: Lock for synchronizing reduce operations, defaults to None
        :type reduce_lock: multiprocessing.Lock, optional
        :param gather_lock: Lock for synchronizing gather operations, defaults to None
        :type gather_lock: multiprocessing.Lock, optional
        :param latest_weights_lock: Lock for the latest weights buffer, defaults to None
        :type latest_weights_lock: multiprocessing.Lock, optional
        :param reduce_iteration: Shared dictionary for reduce iteration counts, defaults to None
        :type reduce_iteration: multiprocessing.Manager.dict, optional
        :param gather_iteration: Shared dictionary for gather iteration counts, defaults to None
        :type gather_iteration: multiprocessing.Manager.dict, optional
        :param max_workers: Size of the thread pool serving gRPC requests, defaults to 4
        :type max_workers: int, optional
        """
        self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
        add_CommServerServicer_to_server(GrpcService(
            load_forward_buffer=load_forward_buffer, load_backward_buffer=load_backward_buffer,
            reduce_ring_buffers=reduce_ring_buffers, gather_ring_buffers=gather_ring_buffers,
            latest_weights_buffer=latest_weights_buffer, forward_lock=forward_lock,
            backward_lock=backward_lock, reduce_lock=reduce_lock, gather_lock=gather_lock,
            latest_weights_lock=latest_weights_lock, reduce_iteration=reduce_iteration,
            gather_iteration=gather_iteration), self.server)
        # Previously measured the *backward* buffer under this label and crashed
        # when no buffer was supplied.
        forward_len = len(load_forward_buffer) if load_forward_buffer is not None else None
        print('Length of forward buffer: ', forward_len, os.getpid())

    def grpc_server_serve(self):
        """Start the gRPC server, raise the ready flag and block until it terminates."""
        self.server.add_insecure_port(self.local_address)
        self.server.start()
        print('Listening on : ', self.local_address)
        # Polled by start() in the parent process.
        self.start_server_flag.value = True
        self.server.wait_for_termination()

    def start_grpc_server(self):
        """Run the (blocking) gRPC serve loop from an asyncio event loop.

        ``grpc_server_serve`` is a plain function, not a coroutine, so it is run in
        the loop's default executor; passing its return value (``None``) to
        ``run_until_complete`` used to raise ``TypeError``.
        """
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(loop.run_in_executor(None, self.grpc_server_serve))
        finally:
            loop.close()

    def start(self):
        """Start the gRPC server process and the buffer-processing thread.

        Spawns a daemon process running :meth:`grpc_server_serve`, waits for it to
        report readiness, then starts a daemon thread running
        :meth:`check_load_forward_buffer`.

        :raises RuntimeError: If the server process exits before becoming ready.
        """
        print('Main process: ', os.getpid())
        serve_process = mp.Process(target=self.grpc_server_serve, daemon=True)
        serve_process.start()
        while not self.start_server_flag.value:
            # Fail loudly instead of spinning forever if the child died during
            # start-up (e.g. the port could not be bound).
            if not serve_process.is_alive() and not self.start_server_flag.value:
                raise RuntimeError('gRPC server process exited before listening on {}'
                                   .format(self.local_address))
            time.sleep(0.5)
        buffer_thread = threading.Thread(target=self.check_load_forward_buffer, daemon=True)
        buffer_thread.start()

    @staticmethod
    def _pop_buffer_head(buffer, lock):
        """Remove and return the oldest entry of ``buffer`` while holding ``lock``."""
        with lock:
            value = buffer[0]
            del buffer[0]
        return value

    def check_load_forward_buffer(self):
        """Process incoming requests forever.

        On every iteration, the backward buffer is checked before the forward
        buffer. Each request is dispatched to the method named by its ``action``.
        The leaf node maps ``FORWARD`` to ``FIND_LOSS`` and ``NO_GRAD_FORWARD`` to
        ``VAL_ACCURACY``. The loop then waits for all send threads spawned during
        the iteration.
        """
        while True:
            send_trigger_threads = []
            # NOTE(review): ``action`` originates from remote nodes and is
            # dispatched by name via getattr, so any method can be invoked;
            # consider an allow-list of ActionTypes values.
            if len(self.load_backward_buffer) != 0:
                value = self._pop_buffer_head(self.load_backward_buffer, self.backward_lock)
                getattr(self, value['action'])(value, send_trigger_threads)
                self.node_status = NodeStatus.IDLE

            if len(self.load_forward_buffer) != 0:
                value = self._pop_buffer_head(self.load_forward_buffer, self.forward_lock)
                action = value['action']
                if self.node_type == NodeTypes.LEAF:
                    if action == ActionTypes.FORWARD:
                        action = ActionTypes.FIND_LOSS
                    elif action == ActionTypes.NO_GRAD_FORWARD:
                        action = ActionTypes.VAL_ACCURACY
                getattr(self, action)(value, send_trigger_threads)

            if send_trigger_threads:
                for send_thread in send_trigger_threads:
                    send_thread.join()
                self.node_status = NodeStatus.IDLE

    def forward_compute(self, tensors=None, **kwargs):
        """Queue a root forward pass and block until the buffer thread has run it.

        Throttles so that at most ``cluster_length`` forward passes are ahead of
        the latest backward pass, and waits for all backwards at every
        ``reduce_threshold``-th forward.

        :param tensors: Input tensors for the forward computation, defaults to None
        :type tensors: torch.Tensor, optional
        :param kwargs: Additional keyword arguments for the computation, defaults to {}
        :type kwargs: dict, optional
        """
        data = {'data': tensors, 'kwargs': kwargs, 'action': ActionTypes.ROOT_FORWARD}
        while self.forward_pass_id - self.latest_backward_id > self.cluster_length:
            time.sleep(0)
        if self.n_forwards % self.reduce_threshold == 0:
            self.wait_for_backwards()
        # Reset the completion flag *before* publishing the request; otherwise
        # the buffer thread may set it True first and this call never returns.
        self.root_compute = False
        with self.forward_lock:
            self.load_forward_buffer.append(data)
        while not self.root_compute:
            time.sleep(0)

    def no_grad_forward_compute(self, tensors=None, output_type=None):
        """Run a root forward pass without gradients and send it downstream.

        :param tensors: Input tensors for the forward pass, defaults to None
        :type tensors: torch.Tensor, optional
        :param output_type: Action the receiving node should run on the output
            (e.g. validation accuracy), defaults to None
        :type output_type: str, optional
        """
        tensors = tensors.to(self.device)
        self.node_status = NodeStatus.FORWARD
        output = self.compute_session.root_no_grad_forward_compute(tensors=tensors)
        payload = self.comm_session.create_no_grad_forward_payload(output, tensors=tensors)
        final_payload = {self.submod_file: payload}
        sent_data = {'data': final_payload,
                     'action': ActionTypes.NO_GRAD_FORWARD,
                     'output_type': output_type}
        self.comm_session.trigger_send(sent_data, type=ActionTypes.FORWARD,
                                       target_host=self.forward_target_host,
                                       target_port=self.forward_target_port)
        print('No Grad forward compute done')
        self.node_status = NodeStatus.IDLE

    def _spawn_send(self, send_threads, sent_data, send_type, target_host, target_port):
        """Run ``comm_session.trigger_send`` on a new thread, recorded in ``send_threads``.

        :return: The started thread.
        """
        t = Thread(target=self.comm_session.trigger_send,
                   args=(sent_data, send_type, target_host, target_port,))
        send_threads.append(t)
        t.start()
        return t

    def _step_lr_scheduler(self):
        """Step the learning rate scheduler if one is configured."""
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

    def _eval_forward(self, model_args):
        """Run ``self.model(*model_args)`` in eval mode under ``torch.no_grad``.

        The model's previous train/eval mode is restored afterwards so that
        validation does not leave the shared model in eval mode for training.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(*model_args)
        finally:
            self.model.train(was_training)

    def root_forward(self, value, send_threads):
        """Run the root forward pass for a queued request and send it downstream.

        Keyword inputs are renamed to ``l_<name>_`` and tensors are moved to the
        node's device. Sets ``root_compute`` so :meth:`forward_compute` can return.
        """
        tensors = value['data']
        if tensors is not None:
            tensors = tensors.to(self.device)
        modified_kwargs = {
            'l_' + key + '_': val.to(self.device) if isinstance(val, torch.Tensor) else val
            for key, val in value['kwargs'].items()
        }
        self.node_status = NodeStatus.FORWARD
        print('Before Root Forward: ')
        check_gpu_usage()
        output = self.compute_session.root_forward_compute(tensors, self.forward_pass_id, **modified_kwargs)
        print('After Root Forward: ')
        check_gpu_usage()
        payload = self.comm_session.create_forward_payload(output, tensors=tensors)
        sent_data = {'forward_pass_id': self.forward_pass_id,
                     'data': {self.submod_file: payload},
                     'action': ActionTypes.FORWARD}
        print('Forward compute done for: ', self.forward_pass_id)
        self.forward_pass_id += 1
        self.comm_session.trigger_send(sent_data, type=ActionTypes.FORWARD,
                                       target_host=self.forward_target_host,
                                       target_port=self.forward_target_port)
        self.n_forwards += 1
        self.root_compute = True
        self.node_status = NodeStatus.IDLE

    def forward(self, value, send_threads):
        """Run a middle (stem) forward pass and send the result downstream."""
        print('n_backwards in FORWARD: ', self.n_backwards)
        self.node_status = NodeStatus.FORWARD
        data = value['data']
        forward_pass_id = value['forward_pass_id']
        print('Start of forward: ', forward_pass_id)
        output = self.compute_session.middle_forward_compute(data, forward_pass_id=forward_pass_id)
        payload = self.comm_session.create_forward_payload(output)
        # The incoming payload dict is extended in place with this node's output.
        final_payload = data
        final_payload[self.submod_file] = payload
        sent_data = {'forward_pass_id': forward_pass_id, 'data': final_payload,
                     'action': ActionTypes.FIND_LOSS}
        self._spawn_send(send_threads, sent_data, ActionTypes.FORWARD,
                         self.forward_target_host, self.forward_target_port)
        self.n_forwards += 1
        print('Forward Done Used RAM %: ', psutil.virtual_memory().percent)

    def no_grad_forward(self, value, send_threads):
        """Run a middle forward pass without gradients and send it downstream."""
        self.node_status = NodeStatus.FORWARD
        print('No grad forward')
        data = value['data']
        output = self.compute_session.middle_no_grad_forward_compute(data)
        payload = self.comm_session.create_no_grad_forward_payload(output)
        final_payload = data
        final_payload[self.submod_file] = payload
        sent_data = {'data': final_payload, 'action': value['output_type']}
        self._spawn_send(send_threads, sent_data, ActionTypes.FORWARD,
                         self.forward_target_host, self.forward_target_port)

    def backward(self, value, send_threads):
        """Apply a backward pass and, unless this is the root, pass gradients upstream.

        Gradients for keys this node passes through are added to the gradients
        it produced itself. Every ``reduce_threshold`` backward passes a parallel
        ring reduce is triggered and the model version is updated.

        :raises ValueError: If a pass-through gradient's dtype disagrees with the
            locally produced one.
        """
        self.node_status = NodeStatus.BACKWARD
        gradient_dict = value['data']
        forward_pass_id = value['forward_pass_id']
        epoch_change = value['epoch_change']
        if epoch_change and self.lr_step_on_epoch_change:
            self._step_lr_scheduler()
        self.latest_backward_id = forward_pass_id
        print('Start of backward: ', forward_pass_id)
        update_flag = (self.n_backwards + 1) % self.update_frequency == 0
        pass_grad_keys = self.compute_session.middle_backward_compute(gradient_dict, forward_pass_id,
                                                                      update_flag=update_flag)
        if update_flag and not self.lr_step_on_epoch_change:
            self._step_lr_scheduler()

        if self.node_type != NodeTypes.ROOT:
            gradients = self.comm_session.create_backward_payload(forward_pass_id=forward_pass_id)
            for pass_key in pass_grad_keys:
                if pass_key in gradients:
                    if gradient_dict[pass_key]['dtype'] != gradients[pass_key]['dtype']:
                        raise ValueError('dtype mismatch for gradient {!r}'.format(pass_key))
                    gradients[pass_key] = {
                        'dtype': gradients[pass_key]['dtype'],
                        'data': gradient_dict[pass_key]['data'].add_(gradients[pass_key]['data']),
                    }
                else:
                    gradients[pass_key] = gradient_dict[pass_key]
            sent_data = {'action': ActionTypes.BACKWARD, 'forward_pass_id': forward_pass_id,
                         'data': gradients, 'epoch_change': epoch_change}
            self._spawn_send(send_threads, sent_data, ActionTypes.BACKWARD,
                             self.backward_target_host, self.backward_target_port)

        if self.input_tensors.get(forward_pass_id, None) is not None:
            del self.input_tensors[forward_pass_id]
        print('Backward done, Used RAM %: ', psutil.virtual_memory().percent)
        self.n_backwards += 1
        if self.n_backwards % self.reduce_threshold == 0:
            self.comm_session.parallel_ring_reduce()
            self.compute_session.update_model_version()

    def find_loss(self, value, send_threads):
        """Compute the loss on the leaf node and send gradients upstream.

        Draws the next targets from ``labels_iterator``, restarting it when it is
        exhausted; a restart counts as an epoch change when
        ``lr_step_on_epoch_change`` is set.
        NOTE(review): when ``labels`` is a Tensor, ``labels_iterator`` is the
        tensor itself and ``next()`` on it fails - confirm tensor labels are
        unsupported here.
        """
        self.node_status = NodeStatus.FORWARD
        epoch_change = False
        targets = next(self.labels_iterator, None)
        if targets is None:
            epoch_change = self.lr_step_on_epoch_change
            self.labels_iterator = iter(self.labels)
            targets = next(self.labels_iterator)
            if epoch_change:
                self._step_lr_scheduler()
            print('\n ---------------------- Reset Data Iterator ------------------------')

        update_flag = (self.n_backwards + 1) % self.update_frequency == 0
        data = value['data']
        model_args = self.compute_session.leaf_find_loss(data, targets=targets, update_flag=update_flag)
        if update_flag and not self.lr_step_on_epoch_change:
            self._step_lr_scheduler()

        gradients = self.comm_session.create_backward_payload(model_args=model_args)
        sent_data = {'action': ActionTypes.BACKWARD, 'data': gradients,
                     'forward_pass_id': value['forward_pass_id'], 'epoch_change': epoch_change}
        self._spawn_send(send_threads, sent_data, ActionTypes.BACKWARD,
                         self.backward_target_host, self.backward_target_port)
        self.n_backwards += 1
        if self.n_backwards % self.reduce_threshold == 0:
            self.comm_session.parallel_ring_reduce()

    def val_accuracy(self, value, send_threads):
        """Accumulate validation accuracy for one batch on the leaf node.

        When as many batch accuracies have been collected as the test iterator's
        length, their mean is printed, appended to ``val_accuracies.txt`` and the
        accumulator is reset.
        """
        data = value['data']
        model_args = self.compute_session.create_no_grad_model_args(data)
        if self.test_labels_iterator is None:
            if isinstance(self.test_labels, torch.Tensor):
                self.test_labels_iterator = self.test_labels
            else:
                self.test_labels_iterator = iter(self.test_labels)
        y_pred = self._eval_forward(model_args)
        _, y_pred_tags = torch.max(y_pred, dim=1)

        y_test = next(self.test_labels_iterator, None)
        if y_test is None:
            self.test_labels_iterator = iter(self.test_labels)
            y_test = next(self.test_labels_iterator)
        # Assumes batches are (inputs, one-hot labels) pairs (original note:
        # "for cnn") - TODO confirm for other datasets.
        y_test = y_test[1].to(self.device)
        y_test = torch.argmax(y_test, dim=1)
        correct_pred = (y_pred_tags == y_test).float()
        val_acc = torch.round(correct_pred.sum() / len(y_test) * 100)
        self.net_val_accuracy.append(val_acc.item())

        if len(self.net_val_accuracy) == len(self.test_labels_iterator):
            validation_accuracy = round(sum(self.net_val_accuracy) / len(self.net_val_accuracy), 2)
            print('Validation Accuracy: ', validation_accuracy)
            with open("val_accuracies.txt", "a") as f:
                f.write(str(validation_accuracy) + '\n')
            self.net_val_accuracy = []

    def accuracy(self, value, send_threads):
        """Print test accuracy of the model output against ``test_labels`` (argmax over the last axis)."""
        print('Finding accuracy')
        data = value['data']
        model_args = self.compute_session.create_no_grad_model_args(data)
        y_pred = self._eval_forward(model_args)
        y_pred = np.argmax(y_pred.detach().cpu().numpy(), axis=-1)
        y_test = np.argmax(self.test_labels, axis=-1)
        accuracy = np.sum(y_pred == y_test, axis=0) / len(y_test)
        print('\nTest Accuracy: ', accuracy)

    def prediction(self, value, send_threads):
        """Print the model's prediction for the received data."""
        data = value['data']
        print('Prediction: ', data)
        # Node has no create_no_grad_model_args of its own; use the compute
        # session's helper like the other evaluation methods.
        model_args = self.compute_session.create_no_grad_model_args(data)
        pred = self._eval_forward(model_args)
        print('Predicted: ', pred)

    def _save_scripted_submodel(self):
        """Script ``self.model`` and save it as ``<template_path>/<submod_file>.pt``.

        The ``submod.pt`` template is deleted afterwards if it exists.
        """
        script = torch.jit.script(self.model)
        os.makedirs(self.template_path, exist_ok=True)
        script.save('{}/{}.pt'.format(self.template_path, self.submod_file))
        submod_path = '{}/submod.pt'.format(self.template_path)
        if os.path.exists(submod_path):
            os.remove(submod_path)

    def save_submodel(self, value, send_threads):
        """Save this submodel and, unless this is the leaf, forward the save request."""
        self._save_scripted_submodel()
        if self.node_type != NodeTypes.LEAF:
            self._spawn_send(send_threads, {'action': ActionTypes.SAVE_SUBMODEL}, ActionTypes.FORWARD,
                             self.forward_target_host, self.forward_target_port)
        print('SAVE done')

    def wait_for_backwards(self):
        """Block (polling once per second) until ``n_backwards`` catches up with ``n_forwards``."""
        while self.n_backwards < self.n_forwards:
            time.sleep(1)

    def trigger_save_submodel(self):
        """Save this submodel to disk and send a save request to the forward target."""
        self._save_scripted_submodel()
        self.comm_session.trigger_send({'action': ActionTypes.SAVE_SUBMODEL}, type=ActionTypes.FORWARD,
                                       target_host=self.forward_target_host,
                                       target_port=self.forward_target_port)
        print('SAVE done')

    def update_with_latest_weights(self):
        """Load the latest weights from the communication session and bump the model version."""
        latest_sd = self.comm_session.get_latest_weights()
        load_state_dict_conserve_versions(self.compute_session.model, latest_sd)
        self.compute_session.update_model_version()
        print('Model latest weights loaded!')

    def reset(self):
        """Delete output left by a previous run.

        Removes the ``<name>_aux`` and ``trained`` directories, the loss file and
        ``val_accuracies.txt`` when they exist.
        """
        for directory in ('{}_aux'.format(self.name), 'trained'):
            if os.path.exists(directory):
                shutil.rmtree(directory)
        for filename in (self.loss_filename, 'val_accuracies.txt'):
            if os.path.exists(filename):
                os.remove(filename)

    def __getstate__(self):
        """Pickle only what the gRPC server process needs (locks, shared buffers, flags)."""
        return dict(forward_lock=self.forward_lock,
                    backward_lock=self.backward_lock,
                    reduce_lock=self.reduce_lock,
                    gather_lock=self.gather_lock,
                    latest_weights_lock=self.latest_weights_lock,
                    local_address=self.local_address,
                    load_forward_buffer=self.load_forward_buffer,
                    load_backward_buffer=self.load_backward_buffer,
                    latest_weights_buffer=self.latest_weights_buffer,
                    reduce_ring_buffers=self.reduce_ring_buffers,
                    gather_ring_buffers=self.gather_ring_buffers,
                    reduce_iteration=self.reduce_iteration,
                    gather_iteration=self.gather_iteration,
                    start_server_flag=self.start_server_flag)

    def __setstate__(self, state):
        """Rebuild the server-side node in the child process and initialise its gRPC server."""
        self.local_address = state['local_address']
        self.start_server_flag = state['start_server_flag']
        self.init_server(load_forward_buffer=state['load_forward_buffer'],
                         load_backward_buffer=state['load_backward_buffer'],
                         reduce_ring_buffers=state['reduce_ring_buffers'],
                         gather_ring_buffers=state['gather_ring_buffers'],
                         latest_weights_buffer=state['latest_weights_buffer'],
                         forward_lock=state['forward_lock'],
                         backward_lock=state['backward_lock'],
                         reduce_lock=state['reduce_lock'],
                         gather_lock=state['gather_lock'],
                         latest_weights_lock=state['latest_weights_lock'],
                         reduce_iteration=state['reduce_iteration'],
                         gather_iteration=state['gather_iteration'])