# Source code for ravnest.operations.utils

import torch
import setuptools
import pickle
from pip._internal.operations.freeze import freeze

from torch.fx import Tracer
from pippy.IR import Pipe
from pippy import split_into_equal_size

from .pippy_utils import split_on_proportions
from .genetic import genetic_algorithm
from .cluster import Cluster
from .node import Node
import numpy as np
import random
import json
import os
import shutil
import copy
import torchinfo
import math
import inspect

def spawn_node_pool(num_nodes=None, mode=None, ram_variants=None, bandwidth_variants=None):
    """Create the pool of compute nodes used for cluster formation.

    With ``mode='load_from_configs'`` the nodes are read from
    ``node_data/node_configs.json``. The file is keyed by ``"0" .. "n-1"``, and each
    entry has an ``'IP'`` and a ``'benchmarks'`` dict. Any other ``mode`` creates
    ``num_nodes`` synthetic nodes. Each one gets a random RAM and bandwidth drawn
    from ``ram_variants`` / ``bandwidth_variants`` and a random dotted-quad address.

    In both modes each node's ``benchmarks['ram']`` is multiplied by 1024. This
    presumably converts GB to MB, to match the megabyte model size it is compared
    against (TODO confirm).

    :param num_nodes: number of synthetic nodes; ignored (recomputed) when loading configs.
    :param mode: ``'load_from_configs'`` or anything else for synthetic nodes.
    :param ram_variants: candidate RAM values for synthetic nodes.
    :param bandwidth_variants: candidate bandwidth values for synthetic nodes.
    :returns: ``(node_pool, total_ram)`` -- the list of ``Node`` objects and the sum
        of their (scaled) RAM.
    """
    node_pool = []
    # Initialised for both branches; the random branch previously used it unbound.
    total_ram = 0
    if mode == 'load_from_configs':
        with open('node_data/node_configs.json') as file:
            node_configs = json.load(file)
        num_nodes = len(node_configs.keys())
        for nid in range(num_nodes):
            node_config = node_configs[str(nid)]
            benchmarks = node_config['benchmarks']
            benchmarks['ram'] *= 1024
            total_ram += benchmarks['ram']
            node = Node(node_id=nid,
                        address=node_config['IP'],
                        benchmarks=benchmarks)
            node_pool.append(node)

    else:
        for nid in range(num_nodes):
            benchmarks = {'ram': random.choice(ram_variants), 'bandwidth': random.choice(bandwidth_variants)}
            random_ip_address = '.'.join(str(np.random.randint(0, 255)) for _ in range(4))
            # Same scaling as the config branch (this previously referenced the
            # undefined ``node_configs`` and always raised NameError).
            benchmarks['ram'] *= 1024
            total_ram += benchmarks['ram']
            node = Node(node_id=nid,
                        address=random_ip_address,
                        benchmarks=benchmarks)
            node_pool.append(node)
    return node_pool, total_ram

def cluster_formation(full_model_size, node_pool):
    """Group ``node_pool`` into clusters using the genetic algorithm.

    ``genetic_algorithm`` returns a mapping from group key to a list of nodes. The
    groups are ordered by key and renumbered ``0..k-1``, and those numbers become
    the cluster ids. Each cluster then gets its split quotas from
    ``calculate_split_percentages`` and its RAM/speed totals from
    ``calculate_cluster_power``.

    :param full_model_size: size of the complete model (same unit as node RAM).
    :param node_pool: list of nodes to distribute.
    :returns: list of ``Cluster`` objects; position ``i`` holds cluster id ``i``.
    :raises ValueError: propagated from ``calculate_split_percentages`` when a
        cluster's RAM is below ``full_model_size``.
    """
    ga_groups = genetic_algorithm(node_pool, full_model_size)
    ordered_members = [ga_groups[key] for key in sorted(ga_groups)]
    clusters = [Cluster(cid) for cid in range(len(ordered_members))]

    # Populate each cluster with the nodes the genetic algorithm assigned to it.
    for cluster, members in zip(clusters, ordered_members):
        for member in members:
            cluster.add_node(member)

    for cluster in clusters:
        calculate_split_percentages(cluster=cluster, full_model_size=full_model_size)
        calculate_cluster_power(cluster=cluster)
    return clusters

def round_percentages(percentages):
    """Round ``percentages`` to integers using the largest-remainder (Hare-Niemeyer) method.

    Every value is first truncated with ``int``. The points still missing to
    bring the truncated total up to 100 are handed out one at a time. Each point
    goes to the entry with the largest remaining fraction (the first such entry
    on ties), and that entry's fraction is then reset to 0.

    :param percentages: non-empty sequence of numbers, expected to sum to about 100.
    :returns: list of ints, one per input value.
    """
    floors = [int(value) for value in percentages]
    fractions = [value - floor for value, floor in zip(percentages, floors)]
    missing = 100 - sum(floors)
    while missing > 0:
        # max over indices returns the first index holding the largest fraction.
        winner = max(range(len(fractions)), key=fractions.__getitem__)
        floors[winner] += 1
        fractions[winner] = 0
        missing -= 1
    return floors

def calculate_cluster_power(cluster):
    """Store aggregate RAM and a speed score on ``cluster``.

    ``cluster.total_ram`` is the sum of the nodes' ``benchmarks['ram']``.
    ``cluster.total_speed`` is the sum of ``ram / bandwidth`` over the nodes,
    truncated to an int.
    """
    members = list(cluster.nodes.values())
    ram_total = sum(member.benchmarks['ram'] for member in members)
    speed_total = sum(member.benchmarks['ram'] / member.benchmarks['bandwidth'] for member in members)
    cluster.total_ram = ram_total
    cluster.total_speed = int(speed_total)


def calculate_split_percentages(cluster, full_model_size):
    """Divide ``full_model_size`` among the nodes of ``cluster`` in proportion to their RAM.

    The steps are:

    1. Compute each node's RAM share as a percentage and round it to a whole
       number with ``round_percentages``.
    2. Convert each percentage to an integer size with floor division.
    3. Deal out any leftover units one at a time, cycling from the first node.

    The result is stored in ``cluster.splits`` and distributed through
    ``cluster.assign_split_quotas_to_nodes()``.

    :raises ValueError: if the cluster's total RAM is smaller than ``full_model_size``.
    """
    node_rams = cluster.rams
    ram_total = sum(node_rams)
    if ram_total < full_model_size:
        raise ValueError("The sum of the cluster rams ({}) does not exceed full_model_size ({}).".format(ram_total, full_model_size))
    shares = round_percentages([(ram / ram_total) * 100 for ram in node_rams])

    splits = [full_model_size * share // 100 for share in shares]
    leftover = full_model_size - sum(splits)
    slot_count = len(splits)
    for step in range(leftover):
        splits[step % slot_count] += 1

    cluster.splits = splits
    cluster.assign_split_quotas_to_nodes()
    
def view_individual_cluster_details(cluster_pool):
    """Print one summary line (id, splits, RAM, speed) per cluster, after a blank separator."""
    print('\n')
    template = 'Cluster_ID: {}  | splits: {} | RAM: {} | Speed: {}'
    for entry in cluster_pool:
        print(template.format(entry.cluster_id, entry.splits, entry.total_ram, entry.total_speed))
        

def assign_connection_targets(cluster_pool):
    """Record, for every node, which nodes of the next cluster hold the same model elements.

    Clusters form a ring: cluster ``i`` is paired with cluster ``i + 1``, and the
    last cluster is paired with cluster ``0``. The function works in three passes:

    1. For each pair, ``representation_converter`` lines up the two clusters'
       ``splits``. For each node of the current cluster it yields a dict that maps
       a next-cluster node position to the (sequence-wide) offset of the first
       shared element.
    2. The positions are replaced by the corresponding next-cluster node's
       ``ip_address``.
    3. The result is stored on each node as
       ``next_cluster_target_node_ip_to_param_mapping``. The offsets are then
       looked up in ``current_cluster.state_dict`` to fill
       ``next_cluster_target_node_ip_to_named_param_mapping``.

    NOTE(review): this reads ``node.ip_address``, whereas ``clusterize`` uses
    ``node.address``. Nothing active in this module sets ``cluster.state_dict``,
    which is indexed here by an integer offset. Confirm that both attributes
    exist on ``Node``/``Cluster`` before relying on this function.

    :param cluster_pool: ordered list of clusters, each with ``splits`` and ``nodes``.
    :returns: ``True``.
    """
    # Independent copies of every cluster's splits, so the originals are never touched.
    temp_splits = [cluster.splits for cluster in cluster_pool]
    splits = copy.deepcopy(temp_splits)
    copy_splits = copy.deepcopy(splits)

    # Pass 1: per-node {next-cluster node position: element offset} dicts.
    for cl in range(len(copy_splits)):
        if cl == len(copy_splits) - 1:
            # Last cluster wraps around to the first one.
            continuous_mapping, _ = representation_converter(cluster_split=copy_splits[0],target_split=copy_splits[cl])
        else:
            continuous_mapping, _ = representation_converter(cluster_split=copy_splits[cl + 1],target_split=copy_splits[cl])
        cluster_pool[cl].inter_cluster_node_address_mappings = continuous_mapping
                
    # Pass 2: replace next-cluster node positions with that node's ip_address.
    for cl in range(len(cluster_pool)):
        current_cluster = cluster_pool[cl]
        if cl == len(cluster_pool) - 1:
            next_cluster = cluster_pool[0]
        else:
            next_cluster = cluster_pool[cl + 1]
        
        address_param_mapping_list = []
        for nid_param_mapping in current_cluster.inter_cluster_node_address_mappings:
            # Mutate a copy so the dict being iterated is not changed mid-loop.
            address_param_mapping = copy.deepcopy(nid_param_mapping)
            for cid, param_idx in nid_param_mapping.items():
                # Position ``cid`` follows the insertion order of next_cluster.nodes.
                actual_nid = list(next_cluster.nodes.keys())[cid]
                del address_param_mapping[cid]
                address_param_mapping[next_cluster.nodes[actual_nid].ip_address] = param_idx
            address_param_mapping_list.append(address_param_mapping)
        current_cluster.inter_cluster_node_address_mappings = address_param_mapping_list

    # Pass 3: attach each node's mapping (matched by position in cluster.nodes)
    # and build the "named" variant through the cluster's state_dict.
    for cl in range(len(cluster_pool)):
        current_cluster = cluster_pool[cl]
        cluster_ip_map = current_cluster.inter_cluster_node_address_mappings
        
        for nid, node in current_cluster.nodes.items():            
            node.next_cluster_target_node_ip_to_param_mapping = cluster_ip_map[list(current_cluster.nodes.keys()).index(nid)]

        cluster_named_inter_cluster_node_address_mappings_list = []
        for nid, node in current_cluster.nodes.items():            
            for key, val in node.next_cluster_target_node_ip_to_param_mapping.items():
                node.next_cluster_target_node_ip_to_named_param_mapping[key] = current_cluster.state_dict[val] 
            cluster_named_inter_cluster_node_address_mappings_list.append(node.next_cluster_target_node_ip_to_named_param_mapping)

        current_cluster.named_inter_cluster_node_address_mappings = cluster_named_inter_cluster_node_address_mappings_list
        
    return True

def representation_converter(cluster_split, target_split):
    """Map each chunk of ``target_split`` onto the chunks of ``cluster_split`` it overlaps.

    Both arguments partition the same flat sequence of elements into consecutive
    chunks, given by their sizes. For every chunk of ``target_split`` this reports
    each overlapped ``cluster_split`` chunk and the position of the first element
    they share.

    :param cluster_split: sizes of the reference chunks; zero-sized chunks are skipped.
    :param target_split: sizes of the chunks to map; their total must not exceed
        ``sum(cluster_split)``.
    :returns: ``(continuous_result, local_result)``. Each is a list with one dict
        per ``target_split`` chunk. A dict maps a ``cluster_split`` chunk index to
        the offset of the first overlapping element. In ``continuous_result`` the
        offset is measured from the start of the whole sequence; in
        ``local_result`` it is measured from the start of the target chunk.
    :raises IndexError: if ``cluster_split`` is empty, or if ``target_split``
        covers more elements than ``cluster_split``.
    """
    cluster_index = 0
    cluster_count = cluster_split[0]
    continuous_result = []
    local_result = []
    start = 0
    for size in target_split:
        end = start + size
        encoded_chunk = [-1] * len(cluster_split)
        for i in range(start, end):
            # ``while`` rather than ``if``: zero-sized cluster chunks must be
            # skipped, otherwise the counter goes negative and every later
            # element is attributed to the empty chunk.
            while cluster_count == 0:
                cluster_index += 1
                cluster_count = cluster_split[cluster_index]
            if encoded_chunk[cluster_index] == -1:
                encoded_chunk[cluster_index] = i - start
            cluster_count -= 1
        local_r = {}
        continuous_r = {}
        for idx, offset in enumerate(encoded_chunk):
            if offset != -1:
                # ``start`` is the number of elements in all previous target chunks.
                continuous_r[idx] = offset + start
                local_r[idx] = offset
        continuous_result.append(continuous_r)
        local_result.append(local_r)
        start = end

    return continuous_result, local_result
    # return out


class CustomTracer(Tracer):
    """``torch.fx`` tracer with a configurable notion of "leaf" modules.

    Leaf modules are kept as single ``call_module`` nodes instead of being traced
    through. On top of the standard ``torch.nn`` rule, this tracer also treats as
    leaves:

    * instances of ``customed_leaf_module`` (a class or tuple of classes);
    * any module whose ``_is_leaf_module`` attribute is truthy.
    """
    def __init__(self, *args, customed_leaf_module=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.customed_leaf_module = customed_leaf_module

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        """Return whether ``m`` should appear as an opaque ``call_module`` node.

        Args:
            m (Module): the module being queried.
            module_qualified_name (str): dotted path of ``m`` from the root module
                (e.g. ``foo.bar.baz``); not used by this implementation.
        """
        extra_leaf_types = self.customed_leaf_module
        if extra_leaf_types and isinstance(m, extra_leaf_types):
            return True

        # Modules can opt in explicitly by setting ``_is_leaf_module``.
        if getattr(m, '_is_leaf_module', False):
            return True

        # Standard library modules are leaves, except containers like Sequential,
        # which must be traced through so their children are visible.
        if isinstance(m, torch.nn.Sequential):
            return False
        return m.__module__.startswith('torch.nn')


# def split_model(model, n_splits=3):
#     custom_tracer = CustomTracer()
#     split_policy = split_into_equal_size(n_splits)
#     pipe = Pipe.from_tracing(model, tracer=custom_tracer, split_policy=split_policy)
#     return pipe

def split_model_on_proportions(model, proportions=None, example_args=None, example_kwargs=None):
    """Trace ``model`` with pippy's export-based tracer and split it into pipeline stages.

    :param model: the ``torch.nn.Module`` to split.
    :param proportions: stage proportions passed to ``split_on_proportions``.
        ``None`` (the default) means an empty list. A fresh list is created per
        call instead of sharing one mutable default.
    :param example_args: positional example inputs used for tracing.
    :param example_kwargs: keyword example inputs used for tracing.
    :returns: the ``Pipe`` built from the split graph.
    """
    if proportions is None:
        proportions = []
    traced = Pipe._trace_with_export(model, example_args=example_args, example_kwargs=example_kwargs)
    split_policy = split_on_proportions(proportions)
    traced = split_policy(traced)
    return Pipe._from_traced(model, traced)

def get_arg_index(name, submod_args):
    """Return the position of the first entry of ``submod_args`` whose ``.name`` equals ``name``, or -1."""
    return next((position for position, arg in enumerate(submod_args) if arg.name == name), -1)

def remake_proportions(proportions):
    """Shift 0.1 from the largest proportion to the one after it, in place.

    The neighbour wraps around, so the last entry donates to the first. The
    first maximum is used on ties.

    :param proportions: non-empty mutable list of floats (modified in place).
    :returns: the same list object.
    """
    donor = proportions.index(max(proportions))
    receiver = (donor + 1) % len(proportions)
    proportions[receiver] += 0.1
    proportions[donor] -= 0.1
    return proportions

def split_model_equal(model=None, num_splits=None, proportions=None, example_args=(), example_kwargs=None, cluster_path=None, node_paths=None, model_input_node=None):
    """Split ``model`` into pipeline stages for one cluster and write each stage's artifacts to disk.

    The model is split with ``split_model_on_proportions`` and smoke-tested with
    one forward pass on the example inputs. If that pass raises, the model is
    re-split once with proportions rebalanced by ``remake_proportions``; the
    re-split pipe is not re-tested.

    For every ``submod_<k>`` stage of the split graph, the following files are
    written under ``<cluster_path>/<node_paths[k]>/``:

    * ``submod_<k>_input.pkl`` -- for each positional input of the stage, where it
      comes from (an earlier stage's whole output, one item of it, or a model input).
    * ``submod_<k>_output.pkl`` -- which later stages consume this stage's outputs.
    * ``submod.pt`` -- the TorchScript-compiled stage.

    The consumers of the whole model's inputs are written to
    ``<cluster_path>/<model_input_node>/model_inputs.pkl``.

    :param model: the ``torch.nn.Module`` to split.
    :param num_splits: unused; kept for backward compatibility.
    :param proportions: stage proportions. Mutated in place if the first split
        fails its test pass.
    :param example_args: positional example inputs for tracing/testing.
    :param example_kwargs: keyword example inputs; ``None`` means none.
    :param cluster_path: directory holding the cluster's node folders.
    :param node_paths: per-stage folder names; stage ``k`` is written to ``node_paths[k]``.
    :param model_input_node: folder name that receives ``model_inputs.pkl``.
    """
    if example_kwargs is None:
        example_kwargs = {}

    pipe = split_model_on_proportions(model, proportions=proportions, example_args=example_args, example_kwargs=example_kwargs)

    print('Testing pipe: ')
    try:
        pipe.forward(*example_args, **example_kwargs)
    except Exception as e:
        # Best effort: some proportions yield a split that fails at runtime.
        # Report why, move 0.1 away from the largest stage and re-split once.
        print('Pipe test failed: ', e)
        proportions = remake_proportions(proportions)
        print('Remade proportions: ', proportions)
        pipe = split_model_on_proportions(model, proportions=proportions, example_args=example_args, example_kwargs=example_kwargs)
    print('Testing finished')

    compiled_input_dict = {}
    compiled_output_dict = {'model_inputs': {}}
    for node in pipe.split_gm.graph.nodes:
        if 'submod' not in node.name:
            continue
        input_dict = {}
        compiled_output_dict[node.name] = {}
        if node.name == 'submod_0':
            # The first stage's arguments are the model inputs themselves.
            submod_0_args = node.args
            for i in range(len(submod_0_args)):
                compiled_output_dict['model_inputs'][i] = {}
        else:
            input_dict = {i: {} for i in range(len(node.args))}
            for arg_index, arg in enumerate(node.args):
                if 'submod' in arg.name:
                    # Whole output of an earlier stage.
                    # NOTE: the producer's entry is keyed by this consumer's argument position.
                    input_dict[arg_index][arg.name] = 'placeholder:tensor'
                    compiled_output_dict[arg.name].setdefault(arg_index, {'target': []})['target'].append(node.name)
                elif 'getitem' in arg.name:
                    # One item of an earlier stage's tuple output: getitem(source, index).
                    source, output_index = arg.args[0], arg.args[1]
                    input_dict[arg_index][source.name] = output_index
                    compiled_output_dict[source.name].setdefault(output_index, {'target': []})['target'].append(node.name)
                else:
                    # Any other argument is one of the model inputs fed to submod_0.
                    index = get_arg_index(arg.name, submod_0_args)
                    input_dict[arg_index]['model_inputs'] = index
                    compiled_output_dict['model_inputs'][index].setdefault('target', []).append(node.name)
        compiled_input_dict[node.name] = input_dict

    for key, value in compiled_output_dict.items():
        if key == 'model_inputs':
            output_path = '{}/{}/{}.pkl'.format(cluster_path, model_input_node, key)
        else:
            k = key.split('_')[-1]
            output_path = '{}/{}/{}_output.pkl'.format(cluster_path, node_paths[int(k)], key)
        with open(output_path, 'wb') as file:
            pickle.dump(value, file)

    for key, value in compiled_input_dict.items():
        k = key.split('_')[-1]
        with open('{}/{}/{}_input.pkl'.format(cluster_path, node_paths[int(k)], key), 'wb') as file:
            pickle.dump(value, file)

    print('\nSubmodels are Saved in: ')
    for key, val in pipe.split_gm._modules.items():
        script = torch.jit.script(val)
        k = key.split('_')[-1]
        # Print the path that is actually written (previously printed ``<key>.pt``).
        save_path = '{}/{}/submod.pt'.format(cluster_path, node_paths[int(k)])
        print(save_path)
        script.save(save_path)

def delete_all_folders(path):
    """Recursively remove every directory directly inside ``path``; plain files are left in place."""
    candidates = [os.path.join(path, entry) for entry in os.listdir(path)]
    for directory in filter(os.path.isdir, candidates):
        shutil.rmtree(directory)

def get_memory_reqs(model=None, input_size=None, input_data=None, depth=3):
    """Estimate the memory needed by ``model``, in whole megabytes (rounded up).

    The input is described in one of two ways (``input_size`` is checked first):

    * ``input_size`` -- a shape whose first entry is the batch size.
      ``torchinfo.summary`` profiles the model on a batch of one, and
      ``total_output_bytes`` is multiplied by the batch size.
    * ``input_data`` -- concrete inputs passed directly to ``torchinfo.summary``.

    The estimate is ``total_input + total_output_bytes + total_param_bytes``,
    converted with ``to_megabytes``. It is also printed.

    NOTE(review): in the ``input_size`` path only the output bytes are scaled by
    the batch size; ``total_input`` is that of the single-sample input. Confirm
    whether that is intended.

    :raises AssertionError: if neither ``input_size`` nor ``input_data`` is given
        (and assertions are enabled).
    """
    assert input_size is not None or input_data is not None
    if input_size is not None:
        batch = input_size[0]
        per_sample_shape = (1, *input_size[1:])
        stats = torchinfo.summary(model, per_sample_shape, depth=depth, verbose=0)
        byte_total = stats.total_input + stats.total_output_bytes * batch + stats.total_param_bytes
    elif input_data is not None:
        stats = torchinfo.summary(model, input_data=input_data, depth=depth, verbose=0)
        byte_total = stats.total_input + stats.total_output_bytes + stats.total_param_bytes
    peak_usage = int(math.ceil(stats.to_megabytes(byte_total)))
    print('\nModel Memory: ', peak_usage)

    return peak_usage

def _build_node_metadata(node, ring_size):
    """Return the JSON-serialisable metadata dict written for ``node`` by ``clusterize``."""
    node_meta = {}
    node_meta['node_id'] = node.node_id
    node_meta['local_host'] = node.address.split(':')[0]
    node_meta['local_port'] = int(node.address.split(':')[1])
    node_meta['template_path'] = 'node_data/cluster_{}/{}/'.format(node.cluster_id, node.address)
    node_meta['rank'] = node.cluster_id
    node_meta['ring_size'] = ring_size
    node_meta['cluster_length'] = node.cluster_length
    # NOTE(review): deliberately a 1-tuple, serialised as a one-element JSON list.
    # This is the format this function has always produced (via a stray trailing
    # comma). Confirm whether ravnest.node.Node expects the bare dict instead.
    node_meta['param_addresses'] = (node.address_to_param,)
    node_meta['ring_ids'] = {int(key): value for key, value in node.ring_ids.items()}
    node_meta['forward_target_host'] = node.forward_target_host
    node_meta['forward_target_port'] = int(node.forward_target_port) if node.forward_target_port is not None else None
    node_meta['backward_target_host'] = node.backward_target_host
    node_meta['backward_target_port'] = int(node.backward_target_port) if node.backward_target_port is not None else None

    # Position in the cluster pipeline: root has only a successor, leaf only a
    # predecessor, stem has both.
    has_forward = node_meta['forward_target_host'] is not None
    has_backward = node_meta['backward_target_host'] is not None
    if has_forward and not has_backward:
        node_meta['node_type'] = 'root'
    elif has_backward and not has_forward:
        node_meta['node_type'] = 'leaf'
    elif has_forward and has_backward:
        node_meta['node_type'] = 'stem'
    else:
        node_meta['node_type'] = None
    return node_meta


def clusterize(model=None, example_args=(), example_kwargs=None, pass_data=False):
    """Takes the complete deep learning model and forms clusters from a pool of compute nodes defined in the ``node_data/node_configs.json`` file.

    Automates address sharing across nodes and reduction ring formation. The
    results are stored as node metadata JSON files, one per node, in the
    ``node_data/nodes/`` folder. These metadata files are later used by the
    ``ravnest.node.Node`` class to load all relevant attributes pertaining to a node.

    Note: every existing sub-folder of ``node_data/`` is deleted first.

    :param model: The complete PyTorch model that needs to be split, defaults to None
    :param example_args: Sample positional inputs (e.g. a torch tensor) that the model expects during the forward pass, defaults to ()
    :param example_kwargs: Any extra sample inputs that the model expects, passed as a dictionary, defaults to None (no keyword inputs)
    :param pass_data: Only used when ``example_kwargs`` is empty. If True, the model size is measured from a pass over ``example_args[0]``. If False, it is estimated from ``example_args[0].shape``. Note that disabling this may not work for all models. Defaults to False
    :raises ValueError: If no example inputs are given, if the total RAM of the node pool does not exceed the model's size, or if the sum of the node RAMs in a cluster is smaller than the model's size.
    """
    if example_kwargs is None:
        example_kwargs = {}
    if len(example_args) == 0 and len(example_kwargs) == 0:
        # Fail before deleting anything: the model cannot be sized without inputs.
        raise ValueError('clusterize requires example_args and/or example_kwargs to size the model.')

    path = 'node_data/'
    delete_all_folders(path)

    # 1. Size of the full model (megabytes, as returned by get_memory_reqs).
    if len(example_kwargs) == 0:
        if pass_data:
            full_model_size = get_memory_reqs(model=model, input_data=example_args[0])
        else:
            full_model_size = get_memory_reqs(model=model, input_size=example_args[0].shape)
    else:
        input_data = {}
        if len(example_args) != 0:
            # Key each positional example input by the name of the caller's local
            # variable holding it (matched by object identity). Inputs with no
            # matching variable get the key None.
            frame = inspect.currentframe().f_back
            variable_names = {id(v): k for k, v in frame.f_locals.items()}
            del frame  # avoid keeping the caller's frame alive via a reference cycle
            for item in example_args:
                input_data[variable_names.get(id(item), None)] = item
        input_data.update(example_kwargs)
        full_model_size = get_memory_reqs(model=model, input_data=input_data)

    # 2. Load the node pool and group it into clusters.
    node_pool, total_ram = spawn_node_pool(mode='load_from_configs')
    if total_ram <= full_model_size:
        raise ValueError('The total RAM of the node pool ({}) does not exceed full_model_size ({}).'.format(total_ram, full_model_size))
    cluster_pool = cluster_formation(full_model_size=full_model_size, node_pool=node_pool)

    for node in node_pool:
        os.makedirs('node_data/cluster_{}/{}'.format(node.cluster_id, node.address), exist_ok=True)
    os.makedirs('node_data/nodes', exist_ok=True)

    # 3. Chain each cluster's nodes into a pipeline and split the model across them.
    for cluster in cluster_pool:
        model_input_node = cluster.nodes[list(cluster.nodes.keys())[0]].address
        cluster_node_ip_addresses = []
        cluster_proportions = []
        for node_id, metadata in cluster.nodes.items():
            cluster_node_ip_addresses.append(metadata.address)
            # NOTE(review): all but the last node get round(1/n, 1). For clusters
            # of more than ~10 nodes the remainder left for the last node becomes
            # ~0 or negative -- confirm.
            if len(cluster_proportions) == len(cluster.nodes) - 1:
                cluster_proportions.append(1 - sum(cluster_proportions))
            else:
                cluster_proportions.append(round(1 / len(cluster.nodes), 1))
        print('cluster props: ', cluster_proportions)

        for i in range(len(cluster_node_ip_addresses)):
            for node in node_pool:
                if node.address == cluster_node_ip_addresses[i]:
                    current_node = node
                    break
            if i < len(cluster_node_ip_addresses) - 1:
                current_node.forward_target_host = cluster_node_ip_addresses[i + 1].split(':')[0]
                current_node.forward_target_port = cluster_node_ip_addresses[i + 1].split(':')[1]
            if i > 0:
                current_node.backward_target_host = cluster_node_ip_addresses[i - 1].split(':')[0]
                current_node.backward_target_port = cluster_node_ip_addresses[i - 1].split(':')[1]

        split_model_equal(model=model, proportions=cluster_proportions, example_args=example_args, example_kwargs=example_kwargs, cluster_path='node_data/cluster_{}'.format(cluster.cid), node_paths=cluster_node_ip_addresses, model_input_node=model_input_node)

    for node in node_pool:
        node.set_submodel()

    # 4. Reduction rings: one ring per node of the largest cluster (first on ties).
    max_c = None
    max_l = 0
    for cluster in cluster_pool:
        if cluster.size > max_l:
            max_l = cluster.size
            max_c = cluster
    print('\nNo. rings: ', max_l)

    for rid, node in enumerate(max_c.nodes.values()):
        node.set_trainable_parameter_keys()
        node.ring_ids[rid] = node.trainable_param_keys[0]
    max_c.set_ringwise_params()

    # Nodes of the other clusters join a ring at the first parameter they hold
    # that belongs to it (per max_c.all_param_to_ring).
    max_ring_size = {}
    for cluster in cluster_pool:
        if cluster.cid != max_c.cid:
            for node in cluster.nodes.values():
                current_ring_id = None
                node.set_trainable_parameter_keys()
                for k in node.trainable_param_keys:
                    ring_id = max_c.all_param_to_ring[k]
                    if current_ring_id != ring_id:
                        node.ring_ids[ring_id] = k
                        current_ring_id = ring_id

        for node in cluster.nodes.values():
            print(node.ring_ids)
            for key in node.ring_ids:
                max_ring_size[key] = max_ring_size.get(key, 0) + 1

    # The per-ring member counts are only printed; the ring size written to the
    # metadata is the number of clusters.
    max_ring_size_value = len(cluster_pool)
    print('Max ring size: ', max_ring_size)

    # 5. For each node, the next cluster's (ring order, wrapping) node addresses
    # where each run of shared trainable parameters starts.
    for cl, cluster in enumerate(cluster_pool):
        next_cluster = cluster_pool[(cl + 1) % len(cluster_pool)]
        for node in cluster.nodes.values():
            current_address = None
            node.cluster_length = len(cluster.nodes)
            for k in node.trainable_param_keys:
                for n_node in next_cluster.nodes.values():
                    if k in n_node.trainable_param_keys and current_address != n_node.address:
                        node.address_to_param[n_node.address] = k
                        current_address = n_node.address

    # 6. Persist one metadata file per node.
    print('\n------------------------------------------------')
    for node in node_pool:
        print(node)
        node_meta = _build_node_metadata(node, max_ring_size_value)
        with open('node_data/nodes/node_{}.json'.format(node.node_id), 'w') as fp:
            json.dump(node_meta, fp)
    print('\nClusters Formed Successfully!')