Skip to content

Commit fbb7be6

Browse files
committed
[doc] specify help string by keyword in configs.
This will ensure the help strings appear in the rendered docs.
1 parent 3ff628c commit fbb7be6

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

jax/_src/xla_bridge.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,24 +68,26 @@
6868
# TODO(phawkins): Remove jax_xla_backend.
6969
_XLA_BACKEND = config.string_flag(
7070
'jax_xla_backend', '',
71-
'Deprecated, please use --jax_platforms instead.')
71+
help='Deprecated, please use --jax_platforms instead.')
7272
BACKEND_TARGET = config.string_flag(
7373
'jax_backend_target',
7474
os.getenv('JAX_BACKEND_TARGET', '').lower(),
75-
'Either "local" or "rpc:address" to connect to a remote service target.')
75+
help='Either "local" or "rpc:address" to connect to a remote service target.')
7676
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
7777
_PLATFORM_NAME = config.string_flag(
7878
'jax_platform_name',
7979
os.getenv('JAX_PLATFORM_NAME', '').lower(),
80-
'Deprecated, please use --jax_platforms instead.')
80+
help='Deprecated, please use --jax_platforms instead.')
8181
CUDA_VISIBLE_DEVICES = config.string_flag(
8282
'jax_cuda_visible_devices', 'all',
83-
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
84-
'comma-separate list of integer device IDs.')
83+
help=(
84+
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
85+
'comma-separate list of integer device IDs.'))
8586
_ROCM_VISIBLE_DEVICES = config.string_flag(
8687
'jax_rocm_visible_devices', 'all',
87-
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
88-
'comma-separate list of integer device IDs.')
88+
help=(
89+
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
90+
'comma-separate list of integer device IDs.'))
8991

9092
MOCK_NUM_GPU_PROCESSES = config.int_flag(
9193
name="mock_num_gpu_processes",
@@ -127,14 +129,16 @@
127129
CROSS_HOST_TRANSFER_TIMEOUT_SECONDS = config.int_flag(
128130
"jax_cross_host_transfer_timeout_seconds",
129131
None,
130-
"Timeout for cross host transfer metadata exchange through KV store. "
131-
"Default is one minute.",
132+
help=(
133+
"Timeout for cross host transfer metadata exchange through KV store. "
134+
"Default is one minute."
135+
),
132136
)
133137

134138
CROSS_HOST_TRANSFER_TRANSFER_SIZE = config.int_flag(
135139
"jax_cross_host_transfer_transfer_size",
136140
None,
137-
"Chunk size for chunked transfer requests."
141+
help="Chunk size for chunked transfer requests."
138142
)
139143

140144
# Warn the user if they call fork(), because it's not going to go well for them.

0 commit comments

Comments
 (0)