import logging
import os
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from .norm import SimpleRMSNorm
# Configure the root logger at import time: records carry a timestamp, level
# and logger name, are written to stdout, and the threshold is taken from the
# LOGLEVEL environment variable (INFO when it is unset).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
# Logger shared by the helper functions defined in this module.
logger = logging.getLogger("print_config")
# Base value; convert_to_multiple_of_base rounds its argument up to a multiple
# of this number.
BASE_DIM = 256
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized."""
    # Short-circuits: is_initialized() is not queried when the package is unavailable.
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when the current process has rank 0."""
    current_rank = get_rank()
    return current_rank == 0
def logging_info(string):
    """Log ``string`` at INFO level from the main process; do nothing elsewhere."""
    if not is_main_process():
        return
    logger.info(string)
def print_params(**kwargs):
    """Log the given keyword arguments as a config dump (main process only).

    The value stored under ``"__class__"`` labels the start/end banner lines.
    Every other entry is logged as ``key: value``, except ``"__class__"`` and
    ``"self"``, which are skipped.

    Args:
        **kwargs: Name/value pairs to log.
    """
    if not is_main_process():
        return

    # Previously a missing "__class__" raised KeyError, and only on the main
    # process. Fall back to a placeholder so a logging helper cannot take
    # down rank 0 while the other ranks keep running.
    owner = kwargs.get("__class__", "<unknown>")

    logger.info(f"start print config of {owner}")
    for key, value in kwargs.items():
        if key in ("__class__", "self"):
            continue
        logger.info(f"{key}: {value}")
    logger.info(f"end print config of {owner}")
def print_config(config):
    """Log the entries of a config mapping (main process only).

    The value stored under ``"__class__"`` labels the start/end banner lines.
    Every other entry is logged as ``key: value``, except ``"__class__"`` and
    ``"self"``, which are skipped.

    Args:
        config: Mapping-like object that supports iteration over its keys and
            subscript access by key.
    """
    if not is_main_process():
        return

    # Previously a missing "__class__" raised KeyError, and only on the main
    # process. Use a placeholder label instead. try/except (rather than
    # .get) keeps working for mapping-like objects that only implement
    # __getitem__ and __iter__.
    try:
        owner = config["__class__"]
    except KeyError:
        owner = "<unknown>"

    logger.info(f"start print config of {owner}")
    for key in config:
        if key in ("__class__", "self"):
            continue
        logger.info(f"{key}: {config[key]}")
    logger.info(f"end print config of {owner}")
def print_module(module):
    """Describe the parameters whose leading name component is not a sub-module name.

    Each such parameter yields one line of the form
    ``(name): Tensor(<shape tuple>, requires_grad=<bool>)``; the lines are
    joined with newlines and the result has no trailing newline.

    Args:
        module: Object exposing ``named_modules()`` and ``named_parameters()``.

    Returns:
        The newline-joined description, or an empty string when no parameter
        qualifies.
    """
    module_names = {mod_name for mod_name, _ in module.named_modules()}

    lines = []
    for full_name, param in module.named_parameters():
        # Only the first dotted component decides whether the parameter is skipped.
        top_level = full_name.split('.')[0]
        if top_level in module_names:
            continue
        lines.append(
            f"({top_level}): Tensor({tuple(param.shape)}, requires_grad={param.requires_grad})"
        )

    return "\n".join(lines)
def get_activation_fn(activation):
    """Return the callable that implements the named activation.

    Recognised names: "gelu", "relu", "elu", "sigmoid", "exp", "leak",
    "1+elu", "2+elu", "silu"/"swish" and "sine". Any other value is logged
    as unsupported and an identity function is returned.

    Args:
        activation: Name of the activation.

    Returns:
        A callable applying the activation to its argument.
    """
    logger.info(f"activation: {activation}")

    # Names that map straight onto an existing torch callable.
    direct_fns = (
        ("gelu", F.gelu),
        ("relu", F.relu),
        ("elu", F.elu),
        ("sigmoid", F.sigmoid),
        ("leak", F.leaky_relu),
        ("silu", F.silu),
        ("swish", F.silu),
        ("sine", torch.sin),
    )
    for known_name, fn in direct_fns:
        if activation == known_name:
            return fn

    if activation == "exp":
        def f(x):
            # The max over the last dim is computed outside autograd, then
            # subtracted from x before exponentiating.
            with torch.no_grad():
                last_dim_max = torch.max(x, dim=-1, keepdims=True).values
            return torch.exp(x - last_dim_max)
        return f

    if activation == "1+elu":
        def f(x):
            return 1 + F.elu(x)
        return f

    if activation == "2+elu":
        def f(x):
            return 2 + F.elu(x)
        return f

    # Unknown name: report it and fall back to the identity function.
    logger.info(f"activation: does not support {activation}, use Identity!!!")
    return lambda x: x

    
def get_norm_fn(norm_type):
    """Return the normalization class selected by ``norm_type``.

    ``"simplermsnorm"`` selects ``SimpleRMSNorm``; every other value selects
    ``nn.LayerNorm``.
    """
    return SimpleRMSNorm if norm_type == "simplermsnorm" else nn.LayerNorm
def convert_to_multiple_of_base(x, base=None):
    """Round ``x`` up to the nearest multiple of ``base``.

    Args:
        x: Integer value to round up.
        base: Multiple to round to. When omitted (``None``), the module-level
            ``BASE_DIM`` is used, which matches the original behaviour.

    Returns:
        The smallest multiple of ``base`` that is greater than or equal to ``x``.
    """
    if base is None:
        base = BASE_DIM
    # Ceiling division via floor division: adding (base - 1) bumps any
    # non-multiple up into the next bucket.
    return base * ((x + base - 1) // base)
