|
68 | 68 | # TODO(phawkins): Remove jax_xla_backend. |
69 | 69 | _XLA_BACKEND = config.string_flag( |
70 | 70 | 'jax_xla_backend', '', |
71 | | - 'Deprecated, please use --jax_platforms instead.') |
| 71 | + help='Deprecated, please use --jax_platforms instead.') |
72 | 72 | BACKEND_TARGET = config.string_flag( |
73 | 73 | 'jax_backend_target', |
74 | 74 | os.getenv('JAX_BACKEND_TARGET', '').lower(), |
75 | | - 'Either "local" or "rpc:address" to connect to a remote service target.') |
| 75 | + help='Either "local" or "rpc:address" to connect to a remote service target.') |
76 | 76 | # TODO(skye): warn when this is used once we test out --jax_platforms a bit |
77 | 77 | _PLATFORM_NAME = config.string_flag( |
78 | 78 | 'jax_platform_name', |
79 | 79 | os.getenv('JAX_PLATFORM_NAME', '').lower(), |
80 | | - 'Deprecated, please use --jax_platforms instead.') |
| 80 | + help='Deprecated, please use --jax_platforms instead.') |
81 | 81 | CUDA_VISIBLE_DEVICES = config.string_flag( |
82 | 82 | 'jax_cuda_visible_devices', 'all', |
83 | | - 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' |
84 | | - 'comma-separate list of integer device IDs.') |
| 83 | + help=( |
| 84 | + 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' |
| 85 | + 'comma-separate list of integer device IDs.')) |
85 | 86 | _ROCM_VISIBLE_DEVICES = config.string_flag( |
86 | 87 | 'jax_rocm_visible_devices', 'all', |
87 | | - 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' |
88 | | - 'comma-separate list of integer device IDs.') |
| 88 | + help=( |
| 89 | + 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' |
| 90 | + 'comma-separate list of integer device IDs.')) |
89 | 91 |
|
90 | 92 | MOCK_NUM_GPU_PROCESSES = config.int_flag( |
91 | 93 | name="mock_num_gpu_processes", |
|
127 | 129 | CROSS_HOST_TRANSFER_TIMEOUT_SECONDS = config.int_flag( |
128 | 130 | "jax_cross_host_transfer_timeout_seconds", |
129 | 131 | None, |
130 | | - "Timeout for cross host transfer metadata exchange through KV store. " |
131 | | - "Default is one minute.", |
| 132 | + help=( |
| 133 | + "Timeout for cross host transfer metadata exchange through KV store. " |
| 134 | + "Default is one minute." |
| 135 | + ), |
132 | 136 | ) |
133 | 137 |
|
134 | 138 | CROSS_HOST_TRANSFER_TRANSFER_SIZE = config.int_flag( |
135 | 139 | "jax_cross_host_transfer_transfer_size", |
136 | 140 | None, |
137 | | - "Chunk size for chunked transfer requests." |
| 141 | + help="Chunk size for chunked transfer requests." |
138 | 142 | ) |
139 | 143 |
|
140 | 144 | # Warn the user if they call fork(), because it's not going to go well for them. |
|
0 commit comments