block

Encoder

class model_center.layer.Encoder(num_layers: int, dim_model: int, dim_ff: int, num_heads: int, dim_head: int, dtype: torch.dtype = torch.float16, int8: bool = False, norm_init_var: float = 1.0, norm_bias: bool = False, norm_eps: float = 1e-05, att_init_mean: float = 0.0, att_init_std: float = 0.02, att_bias: bool = False, att_mask_value: float = - inf, ffn_init_mean: float = 0.0, ffn_init_std: float = 0.02, ffn_bias: bool = False, ffn_activate_fn: str = 'gated_gelu', pos_bias_type: str = 'none', post_layer_norm: bool = False, length_scale: bool = False, attn_scale: bool = False, dropout_p: float = 0, parallel_ffn: bool = False)

Bases: torch.nn.modules.module.Module

A stack of encoder transformer blocks followed by a final layernorm.

Parameters
forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None)
Parameters
  • hidden_states (torch.Tensor of shape (batch, seq_enc, dim_model)) – Input of the encoder, e.g. the embedding of a batch of sequences.

  • attention_mask (torch.Tensor of shape (batch, seq_enc, seq_enc)) – Prevents invalid positions (e.g. padding) from participating in the attention computation.

  • position_bias (torch.Tensor of shape (num_heads, seq_enc, seq_enc)) – Provides positional information to the self-attention blocks.

Returns

The encoder output.

Return type

torch.Tensor of shape (batch, seq_enc, dim_model)
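
A minimal usage sketch (all hyperparameters are illustrative; it assumes a CUDA device, a boolean attention mask where True marks positions allowed to attend, and a BMTrain environment initialized via bmtrain.init_distributed(), e.g. under torchrun):

    import torch
    import bmtrain as bmt
    from model_center.layer import Encoder

    bmt.init_distributed()  # ModelCenter layers are built on BMTrain parameters

    batch, seq_enc, dim_model = 2, 16, 768
    encoder = Encoder(num_layers=2, dim_model=dim_model, dim_ff=3072,
                      num_heads=12, dim_head=64)

    hidden_states = torch.randn(batch, seq_enc, dim_model,
                                dtype=torch.float16, device="cuda")
    # Assumed mask convention: True marks positions allowed to attend.
    attention_mask = torch.ones(batch, seq_enc, seq_enc,
                                dtype=torch.bool, device="cuda")

    out = encoder(hidden_states, attention_mask)  # (batch, seq_enc, dim_model)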

Decoder

class model_center.layer.Decoder(num_layers: int, dim_model: int, dim_ff: int, num_heads: int, dim_head: int, dtype: torch.dtype = torch.float16, int8: bool = False, norm_init_var: float = 1.0, norm_bias: bool = False, norm_eps: float = 1e-05, att_init_mean: float = 0.0, att_init_std: float = 0.02, att_bias: bool = False, att_mask_value: float = - inf, ffn_init_mean: float = 0.0, ffn_init_std: float = 0.02, ffn_bias: bool = False, ffn_activate_fn: str = 'gated_gelu', pos_bias_type: str = 'none', length_scale: bool = False, attn_scale: bool = False, dropout_p: float = 0, parallel_ffn: bool = False)

Bases: torch.nn.modules.module.Module

A stack of decoder transformer blocks followed by a final layernorm.

Parameters
forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: torch.Tensor, cross_hidden_states=None, cross_attention_mask=None, cross_position_bias=None)
Parameters
  • hidden_states (torch.Tensor of shape (batch, seq_dec, dim_model)) – Input of the decoder, e.g. the embedding of a batch of sequences.

  • attention_mask (torch.Tensor of shape (batch, seq_dec, seq_dec)) – Prevents invalid positions from participating in the self-attention computation.

  • position_bias (torch.Tensor of shape (num_heads, seq_dec, seq_dec)) – Provides positional information to the decoder self-attention blocks.

  • cross_hidden_states (torch.Tensor of shape (batch, seq_enc, dim_model)) – Input of the cross-attention blocks, typically the output of the encoder.

  • cross_attention_mask (torch.Tensor of shape (batch, seq_dec, seq_enc)) – Prevents invalid encoder positions from participating in the cross-attention computation.

  • cross_position_bias (torch.Tensor of shape (num_heads, seq_dec, seq_enc)) – Provides positional information to the cross-attention blocks.

Returns

The decoder output.

Return type

torch.Tensor of shape (batch, seq_dec, dim_model)
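
A minimal encoder-decoder sketch under the same assumptions as the Encoder example above (illustrative hyperparameters, CUDA device, boolean masks, BMTrain initialized):

    import torch
    import bmtrain as bmt
    from model_center.layer import Encoder, Decoder

    bmt.init_distributed()

    batch, seq_enc, seq_dec, dim_model = 2, 16, 8, 768
    common = dict(num_layers=2, dim_model=dim_model, dim_ff=3072,
                  num_heads=12, dim_head=64)
    encoder, decoder = Encoder(**common), Decoder(**common)

    enc_hidden = torch.randn(batch, seq_enc, dim_model, dtype=torch.float16, device="cuda")
    dec_hidden = torch.randn(batch, seq_dec, dim_model, dtype=torch.float16, device="cuda")

    enc_mask = torch.ones(batch, seq_enc, seq_enc, dtype=torch.bool, device="cuda")
    # Causal mask: each decoder position attends only to itself and earlier positions.
    dec_mask = torch.tril(torch.ones(seq_dec, seq_dec, device="cuda")).bool().expand(batch, -1, -1)
    cross_mask = torch.ones(batch, seq_dec, seq_enc, dtype=torch.bool, device="cuda")

    enc_out = encoder(enc_hidden, enc_mask)
    # position_bias=None is assumed to be accepted when pos_bias_type='none' (the default).
    dec_out = decoder(dec_hidden, dec_mask, None,
                      cross_hidden_states=enc_out,
                      cross_attention_mask=cross_mask)  # (batch, seq_dec, dim_model)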

TransformerBlock

class model_center.layer.TransformerBlock(dim_model: int, dim_ff: int, num_heads: int, dim_head: int, is_decoder: bool = False, dtype=torch.float16, int8=False, norm_init_var: float = 1.0, norm_bias: bool = False, norm_eps: float = 1e-05, att_init_mean: float = 0.0, att_init_std: float = 0.02, att_bias: bool = False, att_mask_value: float = - inf, ffn_init_mean: float = 0.0, ffn_init_std: float = 0.02, ffn_bias: bool = False, ffn_activate_fn: str = 'gated_gelu', pos_bias_type: str = 'none', post_layer_norm: bool = False, parallel_ffn: bool = False, length_scale: bool = False, attn_scale: bool = False, dropout_p: float = 0)

Bases: torch.nn.modules.module.Module

The complete transformer block: a self-attention block, an optional cross-attention block (when is_decoder is True), and a feed-forward block, applied in sequence.

Parameters
forward(self_hidden_states: torch.Tensor, self_attention_mask: torch.Tensor, self_position_bias: Optional[torch.Tensor] = None, cross_hidden_states=None, cross_attention_mask=None, cross_position_bias=None)
Parameters
  • self_hidden_states (torch.Tensor of shape (batch, seq_self, dim_model)) – Input of the transformer block (self-attention block). It can be the raw embedding of a batch of sequences.

  • self_attention_mask (torch.Tensor of shape (batch, seq_self, seq_self)) – Prevents invalid positions from participating in the self-attention computation.

  • self_position_bias (torch.Tensor of shape (num_heads, seq_self, seq_self)) – Provides positional information to the self-attention block.

  • cross_hidden_states (torch.Tensor of shape (batch, seq_cross, dim_model)) – Input of the cross-attention block, typically the encoder output.

  • cross_attention_mask (torch.Tensor of shape (batch, seq_self, seq_cross)) – Prevents invalid positions from participating in the cross-attention computation.

  • cross_position_bias (torch.Tensor of shape (num_heads, seq_self, seq_cross)) – Provides positional information to the cross-attention block.

Returns

The output of transformer block.

Return type

torch.Tensor of shape (batch, seq_self, dim_model)
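
A minimal usage sketch of a self-attention-only block (is_decoder left at its default of False), under the same assumptions as the Encoder example above:

    import torch
    import bmtrain as bmt
    from model_center.layer import TransformerBlock

    bmt.init_distributed()

    batch, seq_self, dim_model = 2, 16, 768
    block = TransformerBlock(dim_model=dim_model, dim_ff=3072,
                             num_heads=12, dim_head=64)

    hidden = torch.randn(batch, seq_self, dim_model, dtype=torch.float16, device="cuda")
    mask = torch.ones(batch, seq_self, seq_self, dtype=torch.bool, device="cuda")

    out = block(hidden, mask)  # (batch, seq_self, dim_model)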

FFNBlock

class model_center.layer.FFNBlock(dim_model: int, dim_ff: int, dtype=torch.float16, int8=False, norm_init_var: float = 1.0, norm_bias: bool = False, norm_eps: float = 1e-05, ffn_init_mean: float = 0.0, ffn_init_std: float = 0.02, ffn_bias: bool = False, ffn_activate_fn: str = 'gated_gelu', post_layer_norm: bool = False, length_scale: bool = False, dropout_p: float = 0)

Bases: torch.nn.modules.module.Module

The complete feed-forward block: a layernorm, a feed-forward layer, and a residual connection, applied in sequence.

Parameters
forward(hidden_states: torch.Tensor)
Parameters

hidden_states (torch.Tensor of shape (batch, seq_self, dim_model)) – Hidden states before the feed-forward layer.

Returns

The output of the feed-forward block.

Return type

torch.Tensor of shape (batch, seq_self, dim_model)
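
A minimal usage sketch under the same assumptions as the Encoder example above:

    import torch
    import bmtrain as bmt
    from model_center.layer import FFNBlock

    bmt.init_distributed()

    batch, seq_self, dim_model = 2, 16, 768
    ffn = FFNBlock(dim_model=dim_model, dim_ff=3072)

    hidden = torch.randn(batch, seq_self, dim_model, dtype=torch.float16, device="cuda")
    out = ffn(hidden)  # (batch, seq_self, dim_model); the residual connection is applied inside the block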

SelfAttentionBlock

class model_center.layer.SelfAttentionBlock(dim_model: int, num_heads: int, dim_head: int, dtype=torch.float16, int8=False, norm_init_var: float = 1.0, norm_bias: bool = False, norm_eps: float = 1e-05, att_init_mean: float = 0.0, att_init_std: float = 0.02, att_bias: bool = False, att_mask_value: float = - inf, pos_bias_type: str = 'none', post_layer_norm: bool = False, length_scale: bool = False, attn_scale: bool = False, dropout_p: float = 0)

Bases: torch.nn.modules.module.Module

The complete self-attention block: a layernorm, a self-attention layer, and a residual connection, applied in sequence.

Parameters
forward(hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None)
Parameters
  • hidden_states (torch.Tensor of shape (batch, seq_self, dim_model)) – Input of the self-attention block, e.g. the embedding of a batch of sequences.

  • attention_mask (torch.Tensor of shape (batch, seq_self, seq_self)) – Prevents invalid positions from participating in the attention computation.

  • position_bias (torch.Tensor of shape (num_heads, seq_self, seq_self)) – Provides positional information to the self-attention block.

Returns

The output of the self-attention block.

Return type

torch.Tensor of shape (batch, seq_self, dim_model)
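
A minimal usage sketch under the same assumptions as the Encoder example above:

    import torch
    import bmtrain as bmt
    from model_center.layer import SelfAttentionBlock

    bmt.init_distributed()

    batch, seq_self, dim_model = 2, 16, 768
    attn = SelfAttentionBlock(dim_model=dim_model, num_heads=12, dim_head=64)

    hidden = torch.randn(batch, seq_self, dim_model, dtype=torch.float16, device="cuda")
    mask = torch.ones(batch, seq_self, seq_self, dtype=torch.bool, device="cuda")

    out = attn(hidden, mask)  # (batch, seq_self, dim_model)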

CrossAttentionBlock

class model_center.layer.CrossAttentionBlock(dim_model: int, num_heads: int, dim_head: int, dtype=torch.float16, int8=False, norm_init_var: float = 1.0, norm_bias: bool = False, norm_eps: float = 1e-05, att_init_mean: float = 0.0, att_init_std: float = 0.02, att_bias: bool = False, att_mask_value: float = - inf, pos_bias_type: str = 'none', post_layer_norm: bool = False, length_scale: bool = False, attn_scale: bool = False, dropout_p: float = 0)

Bases: torch.nn.modules.module.Module

The complete cross-attention block: a layernorm, a cross-attention layer, and a residual connection, applied in sequence.

Parameters
forward(hidden_states: torch.Tensor, key_value_states: torch.Tensor, attention_mask: torch.Tensor, position_bias: Optional[torch.Tensor] = None)
Parameters
  • hidden_states (torch.Tensor of shape (batch, seq_self, dim_model)) – Input of the cross-attention block; serves as the query in the cross-attention operation.

  • key_value_states (torch.Tensor of shape (batch, seq_cross, dim_model)) – Serves as the key and value in the cross-attention operation.

  • attention_mask (torch.Tensor of shape (batch, seq_self, seq_cross)) – Prevents invalid positions from participating in the attention computation.

  • position_bias (torch.Tensor of shape (num_heads, seq_self, seq_cross)) – Provides positional information to the cross-attention block.

Returns

The output of cross-attention block.

Return type

torch.Tensor of shape (batch, seq_self, dim_model)
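
A minimal usage sketch under the same assumptions as the Encoder example above; the query sequence and the key/value sequence may have different lengths:

    import torch
    import bmtrain as bmt
    from model_center.layer import CrossAttentionBlock

    bmt.init_distributed()

    batch, seq_self, seq_cross, dim_model = 2, 8, 16, 768
    xattn = CrossAttentionBlock(dim_model=dim_model, num_heads=12, dim_head=64)

    query_hidden = torch.randn(batch, seq_self, dim_model, dtype=torch.float16, device="cuda")
    key_value = torch.randn(batch, seq_cross, dim_model, dtype=torch.float16, device="cuda")
    mask = torch.ones(batch, seq_self, seq_cross, dtype=torch.bool, device="cuda")

    out = xattn(query_hidden, key_value, mask)  # (batch, seq_self, dim_model)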