How to write a new model
Model Implementation
We implement our models in model_center/model
We provided commonly used modules in model_center/layer
, such as Linear
, LayerNorm
, Embedding
,
which are implemented based on bmtrain.DistributedParameter
and bmtrain.DistributedModule, for distributed training support.
We have also implemented common ways of combining modules in model_center/layer
, which are block.
For example, SelfAttentionBlock
combines Layernorm, Attention, Add&Norm together.
Each blocks has diverse option, e.g., FFNBlock
supports gated_relu
, relu
, gated_gelu
, gelu
; blocks support pre-layernorm and post-layernorm.
With the help of these commonly used modules we provided, a new model can be written easily without many exceptions. You can just add the model specific feature into the common structure.
A classic transformer is implemented in the following structure:
We use bmtrain.CheckpointBlock, and bmtrain.TransformerBlockList to wrap our transformer blocks. These reducd the GPU memory usage by a great amount without adding lots of computation time. For more information, see BMTrain’s Quick Start
T5(
(input_embedding): Embedding()
(position_bias_enc): RelativePositionEmbedding()
(position_bias_dec): RelativePositionEmbedding()
(encoder): Encoder(
(layers): bmtrain.TransformerBlockList(
(0): bmtrain.CheckpointBlock(
TransformerBlock(
(self_att): SelfAttentionBlock(
(layernorm_before_attention): LayerNorm()
(attention): Attention(
(project_q): Linear()
(project_k): Linear()
(project_v): Linear()
(attention_out): Linear()
)
)
(ffn): FFNBlock(
(layernorm_before_ffn): LayerNorm()
(ffn): FeedForward(
(w_in): DenseACT(
(w): Linear()
(act): ReLU()
)
(w_out): Linear()
)
)
)
)
(1): bmtrain.CheckpointBlock()
.
.
.
)
(output_layernorm): LayerNorm()
)
(decoder): Decoder(
(layers): bmtrain.TransformerBlockList(
(0): bmtrain.CheckpointBlock(
(self_att): SelfAttentionBlock(
(layernorm_before_attention): LayerNorm()
(attention): Attention(
(project_q): Linear()
(project_k): Linear()
(project_v): Linear()
(attention_out): Linear()
)
)
(cross_att): CrossAttentionBlock(
(layernorm_before_attention): LayerNorm()
(attention): Attention(
(project_q): Linear()
(project_k): Linear()
(project_v): Linear()
(attention_out): Linear()
)
)
(ffn): FFNBlock(
(layernorm_before_ffn): LayerNorm()
(ffn): FeedForward(
(w_in): DenseACT(
(w): Linear()
(act): ReLU()
)
(w_out): Linear()
)
)
)
(1): bmtrain.CheckpointBlock()
.
.
.
)
(output_layernorm): LayerNorm()
)
(output_projection): Linear(
(weight): bmtrain.DistributedParameter()
(bias): bmtrain.DistributedParameter()
)
)
Model Config
We add model configs in model_center/model/config
By inheriting model_center.config.Config
, config class can parse json files with config.from_json_file(path)
method,
the parsed json file are then save to the config class and used by model by instantiating model with model(config)
.