
Use bfloat16 for eval #66

@tbaker2

I'm running paxml on an Intel Xeon CPU server using the paxml/main.py program. I'm trying to build a model whose weights are created in bfloat16 and that uses that datatype during eval. I modified the LmCloudSpmd2B configuration with the following lines:

MODEL_DTYPE = jnp.bfloat16
ICI_MESH_SHAPE = [1, 1, 1]

The training status output includes the following lines:

model.dtype : type/jax.numpy/float32
model.fprop_dtype : dtype[bfloat16]

All of the other operator datatypes are float32. When I run that model with the --eval switch, all of the computation is done in float32. How can I direct paxml to use bfloat16?
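For illustration only, here is a minimal plain-JAX sketch (not paxml or Praxis code) of the distinction the two status fields seem to describe: parameters stored in one dtype (float32, like the reported model.dtype) while the forward pass casts to another (bfloat16, like fprop_dtype). The names `params`, `forward`, and `cast_params` are hypothetical. Casting inside the forward pass keeps float32 weights. Converting the whole parameter tree, as `cast_params` does, is one way to hold bfloat16 weights for an eval-only run, assuming the weights can be cast after they are created or loaded.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree, stored in float32 (analogous to model.dtype).
key = jax.random.PRNGKey(0)
params = {
    "w": jax.random.normal(key, (4, 8), dtype=jnp.float32),
    "b": jnp.zeros((8,), dtype=jnp.float32),
}

def forward(params, x, fprop_dtype=jnp.bfloat16):
    # Cast weights and inputs to the forward-pass dtype so the
    # matmul and add themselves run in fprop_dtype.
    w = params["w"].astype(fprop_dtype)
    b = params["b"].astype(fprop_dtype)
    return x.astype(fprop_dtype) @ w + b

def cast_params(params, dtype):
    # Convert every floating-point leaf of the tree to `dtype`;
    # non-float leaves (e.g. integer counters) are left unchanged.
    return jax.tree_util.tree_map(
        lambda p: p.astype(dtype) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )

x = jnp.ones((2, 4), dtype=jnp.float32)
y = forward(params, x)
print(y.dtype)  # bfloat16

bf16_params = cast_params(params, jnp.bfloat16)
print(bf16_params["w"].dtype)  # bfloat16
```

Printing `jax.make_jaxpr(forward)(params, x)` is one way to confirm which dtype the matmul actually uses.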

Tom
