31 changes: 29 additions & 2 deletions praxis/base_layer.py
@@ -110,6 +110,7 @@

# Postfix for quantized scale and zero point names.
QUANTIZED_SCALE_NAME_POSTFIX = '_quantized_scale'
QUANTIZED_SCALE_ACT_NAME_POSTFIX = '_quantized_act_scale'
QUANTIZED_ZP_NAME_POSTFIX = '_quantized_zp'

# Postfix for sparsity mask
@@ -2294,13 +2295,20 @@ def create_quantized_variable(
if scale_hparams is None:
scale_hparams = WeightHParams(shape=scale_shape)
else:
if len(scale_shape) > 0:
raise ValueError('Provide either scale_shape or scale_hparams, not both.')
self.create_variable(name=name, var_hparams=quantized_weight_hparams)
self.create_variable(
name=name + QUANTIZED_SCALE_NAME_POSTFIX,
var_hparams=scale_hparams,
)
dtype = weight_hparams.dtype
# For fp8 weight dtypes, also create a per-tensor activation scale variable.
if (jax.dtypes.scalar_type_of(dtype) == float
and jnp.finfo(dtype).bits == 8
):
self.create_variable(
name=name + QUANTIZED_SCALE_ACT_NAME_POSTFIX,
var_hparams=scale_hparams,
)
if not use_symmetric:
self.create_variable(
name=name + QUANTIZED_ZP_NAME_POSTFIX,
@@ -2335,6 +2343,25 @@ def get_quantized_weight(
zp = None if use_symmetric else self.theta[zp_name]
return self.theta[name], self.theta[scale_name], zp

@nn.nowrap
def get_quantized_act_scale(
self, name: str,
) -> JTensor:
"""Gets the quantized activation scale.

`name` is the name of the weight tensor; assumes the activation scale
tensor has the postfix `_quantized_act_scale`.

Args:
name: Variable name for the weight tensor.

Returns:
Activation scale tensor.
"""
scale_act_name = name + QUANTIZED_SCALE_ACT_NAME_POSTFIX
return self.theta[scale_act_name]

@nn.nowrap
def create_sparse_variable(self, name: str, weight_hparams: WeightHParams):
"""Creates the weight and mask tensors for sparse variables.
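A minimal standalone sketch of the fp8-detection predicate that gates the new `_quantized_act_scale` variable above; the helper name `is_fp8` is illustrative, but the dtype check mirrors the one in `create_quantized_variable`:

import jax
import jax.numpy as jnp

def is_fp8(dtype) -> bool:
  # True only for 8-bit float dtypes such as float8_e4m3fn; for integer dtypes
  # the first check fails, so jnp.finfo is never reached.
  return jax.dtypes.scalar_type_of(dtype) == float and jnp.finfo(dtype).bits == 8

assert is_fp8(jnp.float8_e4m3fn)
assert not is_fp8(jnp.bfloat16) and not is_fp8(jnp.int8)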
6 changes: 3 additions & 3 deletions praxis/layers/quantization/attentions.py
@@ -110,7 +110,7 @@ def _get_weight_scale_shape(self, block_size, use_block_size):
else:
weight_shape = [self.input_dim] + hd_shape

scale_shape = [self.input_dim] if self.is_output_projection else hd_shape
scale_shape = [1]

if block_size > 0 and use_block_size:
eqn = self._get_eqn()
@@ -546,9 +546,9 @@ def setup(self) -> None:
self.set_up_weights(
weight_name='w',
weight_params=pc,
scale_shape=[3] + hd_shape,
scale_shape=[1],
)
self.create_sparsity_variables('w', pc, scale_shape=[3] + hd_shape)
self.create_sparsity_variables('w', pc, scale_shape=[1])
if self.use_bias:
# Combined bias weight for q, k, v projections.
pc_bias = WeightHParams(
5 changes: 3 additions & 2 deletions praxis/layers/quantization/linears.py
@@ -106,7 +106,7 @@ def _get_weight_hparams(
"""
wp = self.weight_split_dims_mapping
weight_shape = [self.input_dims, self.output_dims]
scale_shape = [self.output_dims]
scale_shape = [1]
block_size = self._sub_channel_block_size()
if using_sub_channel:
weight_shape = self._get_sub_channel_shape(weight_shape, block_size, 0)
@@ -190,11 +190,12 @@ def setup(self) -> None:
weight_name='w',
weight_params=weight_hparams,
scale_hparams=scale_hparams,
scale_shape=[1],
)
self.create_sparsity_variables(
'w',
weight_hparams,
scale_shape=[self.output_dims],
scale_shape=[1],
)

def __call__(self, inputs: JTensor) -> JTensor:
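To make the shape changes in attentions.py and linears.py concrete, here is a hedged sketch (array names and reduction axes are illustrative, not the library's API) contrasting the old per-channel int8 scale with the single per-tensor scale of shape [1] used on the fp8 path:

import jax.numpy as jnp

w = jnp.ones((128, 64), jnp.float32)  # [input_dims, output_dims]
# Per-channel int8 scale: one value per output channel -> shape [64].
per_channel_scale = jnp.max(jnp.abs(w), axis=0) / 127.0
# Per-tensor fp8 scale: one value for the whole tensor -> shape [1].
fp8_max = float(jnp.finfo(jnp.float8_e4m3fn).max)  # 448 for e4m3fn
per_tensor_scale = (jnp.max(jnp.abs(w)) / fp8_max).reshape(1)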
16 changes: 13 additions & 3 deletions praxis/layers/quantization/operations.py
@@ -28,6 +28,7 @@
from praxis.layers.quantization import optimization
from praxis.layers.quantization import quantization_hparams
from praxis.layers.quantization import utils
from flax.linen import fp8_ops


JTensor = pytypes.JTensor
@@ -601,11 +602,12 @@ def einsum(
w_dequantized = _dequantize(w, scale, zp, eqn_to_weight_contract_dims(eqn))
return jnp.einsum(eqn, x_dequantized, w_dequantized)

use_fp8 = False
if (
jax.dtypes.scalar_type_of(w.dtype) == float
and jnp.finfo(w.dtype).bits == 8
):
w = w.astype(jnp.bfloat16)
use_fp8 = True  # w stays in fp8

if x.dtype in INT_TYPES and w.dtype in INT_TYPES:
assert not swap_xw, 'No need to swap x and w when both are int types.'
@@ -626,10 +628,18 @@ def einsum(
if swap_xw:
ret = jnp.einsum(eqn_normalized, w, x)
else:
ret = jnp.einsum(eqn_normalized, x, w)
dot_general_with_precision = lambda lhs, rhs, dimension_numbers, precision=None, preferred_element_type=jnp.bfloat16: lax.dot_general(
lhs,
rhs,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,  # TODO: use the proper output type.
)
ret = jnp.einsum(eqn_normalized, x, w, preferred_element_type=jnp.bfloat16)

if scale_act is not None:
if scale_act.ndim == 0:
if scale_act.ndim == 0 or use_fp8:
scale *= scale_act
else:
ret *= jnp.expand_dims(scale_act, _get_expand_dims_lhs(eqn))
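A hedged sketch of the behaviour the einsum hunk above moves toward (the function `fp8_style_einsum` and its arguments are assumptions for illustration, not `operations.einsum` itself): the fp8 weight is upcast to bf16 for the contraction, the accumulation type is pinned via preferred_element_type, and a per-tensor activation scale is folded into the weight scale instead of being broadcast onto the output:

import jax.numpy as jnp

def fp8_style_einsum(eqn, x, w_fp8, w_scale, act_scale):
  # bf16 x fp8 contractions are not generally supported, so upcast the weight.
  out = jnp.einsum(eqn, x.astype(jnp.bfloat16), w_fp8.astype(jnp.bfloat16),
                   preferred_element_type=jnp.bfloat16)
  # Per-tensor scales commute with the contraction, so apply them once at the end.
  return out * (w_scale * act_scale)

x = jnp.ones((2, 4), jnp.bfloat16)
w = jnp.ones((4, 3), jnp.float8_e4m3fn)
y = fp8_style_einsum('bd,df->bf', x, w, 0.5, 2.0)  # shape (2, 3)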
53 changes: 32 additions & 21 deletions praxis/layers/quantization/quantizer.py
@@ -43,6 +43,7 @@
QuantizationParams = quantization_hparams.QuantizationParams
instance_field = base_layer.instance_field
WeightQuantizationParams = quantization_hparams.WeightQuantizationParams
QUANTIZED_SCALE_ACT_NAME_POSTFIX = '_quantized_act_scale'


class QuantizationLayer(base_layer.BaseLayer):
@@ -131,6 +132,7 @@ def set_up_weights(
jax.dtypes.scalar_type_of(dtype) == float
and jnp.finfo(dtype).bits == 8
):
# Remember the original fp8 dtype; the variable itself is stored as int8.
weight_params.dtype = dtype
dtype = jnp.int8
self.create_quantized_variable(
weight_name,
@@ -219,6 +221,14 @@ def quantized_einsum(
else:
return jnp.einsum(eqn, x, w)

dtype = self.quantization.weight_params.dtype
use_fp8 = (
jax.dtypes.scalar_type_of(dtype) == float
and jnp.finfo(dtype).bits == 8
)

# Optionally create step count.
step_count = None
if self.quantization.weight_params.use_step_count:
@@ -249,28 +259,29 @@ def quantized_einsum(
x = x.astype(jnp.int8)
logging.info('Static activation quantization is not supported yet.')
elif self.quantization.act_params is not None:
act_params = self.quantization.act_params
x, scale_act, zp_act = operations.reduce_einsum_activation_precision(
eqn,
x,
bits=act_params.precision,
per_channel=act_params.per_channel,
symmetric=act_params.symmetric,
percentile=act_params.clipping_coeff,
)
if act_params.precision <= 8:
if act_params.symmetric:
# TODO(rybakov): add support for asymmetric too.
x = x.astype(jnp.int8)

dtype = self.quantization.weight_params.dtype
if (
jax.dtypes.scalar_type_of(dtype) == float
and jnp.finfo(dtype).bits == 8
):
if not use_fp8:
act_params = self.quantization.act_params
x, scale_act, zp_act = operations.reduce_einsum_activation_precision(
eqn,
x,
bits=act_params.precision,
per_channel=act_params.per_channel,
symmetric=act_params.symmetric,
percentile=act_params.clipping_coeff,
)
if act_params.precision <= 8 and act_params.symmetric:
# TODO(rybakov): add support for asymmetric too.
x = x.astype(jnp.int8)
else:
# Per-tensor activation quantization for the fp8 path.
scale_act = self.get_quantized_act_scale(weight_name)
x = fp8_ops_linen.quantize(x, jnp.float8_e4m3fn, scale_act, jnp.bfloat16)

if use_fp8:
# cast from int8 to fp8
w = jax.lax.bitcast_convert_type(w, dtype)
# cast to bf16 since bf16 x fp8 is not supported.
w = w.astype(jnp.bfloat16)

out = operations.einsum(
eqn,
x,
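Finally, a small sketch of the storage convention assumed by set_up_weights and quantized_einsum above (fp8 weights held in an int8 variable and bitcast back before use); the variable names are illustrative:

import jax
import jax.numpy as jnp

w_fp8 = jnp.full((4, 4), 0.5, dtype=jnp.float8_e4m3fn)
# Stored as raw int8 bits; the original fp8 dtype is remembered on weight_params.
w_stored = jax.lax.bitcast_convert_type(w_fp8, jnp.int8)
# At use time: bitcast back to fp8, then upcast to bf16 for the einsum.
w_restored = jax.lax.bitcast_convert_type(w_stored, jnp.float8_e4m3fn)
w_compute = w_restored.astype(jnp.bfloat16)
assert bool((w_restored == w_fp8).all())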