diff --git a/include/net/netns/smc.h b/include/net/netns/smc.h index fc752a50f91b4..d66bfd3922546 100644 --- a/include/net/netns/smc.h +++ b/include/net/netns/smc.h @@ -17,6 +17,9 @@ struct netns_smc { #ifdef CONFIG_SYSCTL struct ctl_table_header *smc_hdr; #endif +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) + struct smc_hs_ctrl __rcu *hs_ctrl; +#endif /* CONFIG_SMC_HS_CTRL_BPF */ unsigned int sysctl_autocorking_size; unsigned int sysctl_smcr_buf_type; int sysctl_smcr_testlink_time; diff --git a/include/net/smc.h b/include/net/smc.h index 08bee529ed8d4..bfdc4c41f0198 100644 --- a/include/net/smc.h +++ b/include/net/smc.h @@ -17,6 +17,8 @@ #include #include +struct tcp_sock; +struct inet_request_sock; struct sock; #define SMC_MAX_PNETID_LEN 16 /* Max. length of PNET id */ @@ -50,4 +52,55 @@ struct smcd_dev { u8 going_away : 1; }; +#define SMC_HS_CTRL_NAME_MAX 16 + +enum { + /* ops can be inherit from init_net */ + SMC_HS_CTRL_FLAG_INHERITABLE = 0x1, + + SMC_HS_CTRL_ALL_FLAGS = SMC_HS_CTRL_FLAG_INHERITABLE, +}; + +struct smc_hs_ctrl { + /* private */ + + struct list_head list; + struct module *owner; + + /* public */ + + /* unique name */ + char name[SMC_HS_CTRL_NAME_MAX]; + int flags; + + /* Invoked before computing SMC option for SYN packets. + * We can control whether to set SMC options by returning various value. + * Return 0 to disable SMC, or return any other value to enable it. + */ + int (*syn_option)(struct tcp_sock *tp); + + /* Invoked before Set up SMC options for SYN-ACK packets + * We can control whether to respond SMC options by returning various + * value. Return 0 to disable SMC, or return any other value to enable + * it. + */ + int (*synack_option)(const struct tcp_sock *tp, + struct inet_request_sock *ireq); +}; + +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) +#define smc_call_hsbpf(init_val, tp, func, ...) ({ \ + typeof(init_val) __ret = (init_val); \ + struct smc_hs_ctrl *ctrl; \ + rcu_read_lock(); \ + ctrl = rcu_dereference(sock_net((struct sock *)(tp))->smc.hs_ctrl); \ + if (ctrl && ctrl->func) \ + __ret = ctrl->func(tp, ##__VA_ARGS__); \ + rcu_read_unlock(); \ + __ret; \ +}) +#else +#define smc_call_hsbpf(init_val, tp, ...) ({ (void)(tp); (init_val); }) +#endif /* CONFIG_SMC_HS_CTRL_BPF */ + #endif /* _SMC_H */ diff --git a/kernel/bpf/bpf_struct_ops.c b/kernel/bpf/bpf_struct_ops.c index a41e6730edcf3..278490683d288 100644 --- a/kernel/bpf/bpf_struct_ops.c +++ b/kernel/bpf/bpf_struct_ops.c @@ -1162,6 +1162,7 @@ bool bpf_struct_ops_get(const void *kdata) map = __bpf_map_inc_not_zero(&st_map->map, false); return !IS_ERR(map); } +EXPORT_SYMBOL_GPL(bpf_struct_ops_get); void bpf_struct_ops_put(const void *kdata) { @@ -1173,6 +1174,7 @@ void bpf_struct_ops_put(const void *kdata) bpf_map_put(&st_map->map); } +EXPORT_SYMBOL_GPL(bpf_struct_ops_put); u32 bpf_struct_ops_id(const void *kdata) { diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c index f62d61b6730a2..3ccb99e9121c5 100644 --- a/kernel/bpf/syscall.c +++ b/kernel/bpf/syscall.c @@ -1234,6 +1234,7 @@ int bpf_obj_name_cpy(char *dst, const char *src, unsigned int size) return src - orig_src; } +EXPORT_SYMBOL_GPL(bpf_obj_name_cpy); int map_check_no_btf(const struct bpf_map *map, const struct btf *btf, diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index b94efb3050d2f..a36238874b200 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -802,34 +803,36 @@ static void tcp_options_write(struct tcphdr *th, struct tcp_sock *tp, mptcp_options_write(th, ptr, tp, opts); } -static void smc_set_option(const struct tcp_sock *tp, +static void smc_set_option(struct tcp_sock *tp, struct tcp_out_options *opts, unsigned int *remaining) { #if IS_ENABLED(CONFIG_SMC) - if (static_branch_unlikely(&tcp_have_smc)) { - if (tp->syn_smc) { - if (*remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { - opts->options |= OPTION_SMC; - *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; - } + if (static_branch_unlikely(&tcp_have_smc) && tp->syn_smc) { + tp->syn_smc = !!smc_call_hsbpf(1, tp, syn_option); + /* re-check syn_smc */ + if (tp->syn_smc && + *remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { + opts->options |= OPTION_SMC; + *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; } } #endif } static void smc_set_option_cond(const struct tcp_sock *tp, - const struct inet_request_sock *ireq, + struct inet_request_sock *ireq, struct tcp_out_options *opts, unsigned int *remaining) { #if IS_ENABLED(CONFIG_SMC) - if (static_branch_unlikely(&tcp_have_smc)) { - if (tp->syn_smc && ireq->smc_ok) { - if (*remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { - opts->options |= OPTION_SMC; - *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; - } + if (static_branch_unlikely(&tcp_have_smc) && tp->syn_smc && ireq->smc_ok) { + ireq->smc_ok = !!smc_call_hsbpf(1, tp, synack_option, ireq); + /* re-check smc_ok */ + if (ireq->smc_ok && + *remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { + opts->options |= OPTION_SMC; + *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; } } #endif diff --git a/net/smc/Kconfig b/net/smc/Kconfig index 99ecd59d1f4b8..325addf83cc69 100644 --- a/net/smc/Kconfig +++ b/net/smc/Kconfig @@ -19,3 +19,13 @@ config SMC_DIAG smcss. if unsure, say Y. + +config SMC_HS_CTRL_BPF + bool "Generic eBPF hook for SMC handshake flow" + depends on SMC && BPF_SYSCALL + default y + help + SMC_HS_CTRL_BPF enables support to register generic eBPF hook for SMC + handshake flow, which offer much greater flexibility in modifying the behavior + of the SMC protocol stack compared to a complete kernel-based approach. Select + this option if you want filtring the handshake process via eBPF programs. \ No newline at end of file diff --git a/net/smc/Makefile b/net/smc/Makefile index 0e754cbc38f9c..b7f540a4ebea2 100644 --- a/net/smc/Makefile +++ b/net/smc/Makefile @@ -6,3 +6,4 @@ smc-y := af_smc.o smc_pnet.o smc_ib.o smc_clc.o smc_core.o smc_wr.o smc_llc.o smc-y += smc_cdc.o smc_tx.o smc_rx.o smc_close.o smc_ism.o smc_netlink.o smc_stats.o smc-y += smc_tracepoint.o smc_inet.o smc-$(CONFIG_SYSCTL) += smc_sysctl.o +smc-$(CONFIG_SMC_HS_CTRL_BPF) += smc_hs_bpf.o \ No newline at end of file diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c index 77b99e8ef35a4..890b36e90ba76 100644 --- a/net/smc/af_smc.c +++ b/net/smc/af_smc.c @@ -58,6 +58,7 @@ #include "smc_tracepoint.h" #include "smc_sysctl.h" #include "smc_inet.h" +#include "smc_hs_bpf.h" static DEFINE_MUTEX(smc_server_lgr_pending); /* serialize link group * creation on server @@ -3600,8 +3601,16 @@ static int __init smc_init(void) pr_err("%s: smc_inet_init fails with %d\n", __func__, rc); goto out_ulp; } + rc = bpf_smc_hs_ctrl_init(); + if (rc) { + pr_err("%s: bpf_smc_hs_ctrl_init fails with %d\n", __func__, + rc); + goto out_inet; + } static_branch_enable(&tcp_have_smc); return 0; +out_inet: + smc_inet_exit(); out_ulp: tcp_unregister_ulp(&smc_ulp_ops); out_ib: diff --git a/net/smc/smc_hs_bpf.c b/net/smc/smc_hs_bpf.c new file mode 100644 index 0000000000000..063d23d858508 --- /dev/null +++ b/net/smc/smc_hs_bpf.c @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Generic hook for SMC handshake flow. + * + * Copyright IBM Corp. 2016 + * Copyright (c) 2025, Alibaba Inc. + * + * Author: D. Wythe + */ + +#include +#include +#include +#include + +#include "smc_hs_bpf.h" + +static DEFINE_SPINLOCK(smc_hs_ctrl_list_lock); +static LIST_HEAD(smc_hs_ctrl_list); + +static int smc_hs_ctrl_reg(struct smc_hs_ctrl *ctrl) +{ + int ret = 0; + + spin_lock(&smc_hs_ctrl_list_lock); + /* already exist or duplicate name */ + if (smc_hs_ctrl_find_by_name(ctrl->name)) + ret = -EEXIST; + else + list_add_tail_rcu(&ctrl->list, &smc_hs_ctrl_list); + spin_unlock(&smc_hs_ctrl_list_lock); + return ret; +} + +static void smc_hs_ctrl_unreg(struct smc_hs_ctrl *ctrl) +{ + spin_lock(&smc_hs_ctrl_list_lock); + list_del_rcu(&ctrl->list); + spin_unlock(&smc_hs_ctrl_list_lock); + + /* Ensure that all readers to complete */ + synchronize_rcu(); +} + +struct smc_hs_ctrl *smc_hs_ctrl_find_by_name(const char *name) +{ + struct smc_hs_ctrl *ctrl; + + list_for_each_entry_rcu(ctrl, &smc_hs_ctrl_list, list) { + if (strcmp(ctrl->name, name) == 0) + return ctrl; + } + return NULL; +} + +static int __smc_bpf_stub_set_tcp_option(struct tcp_sock *tp) { return 1; } +static int __smc_bpf_stub_set_tcp_option_cond(const struct tcp_sock *tp, + struct inet_request_sock *ireq) +{ + return 1; +} + +static struct smc_hs_ctrl __smc_bpf_hs_ctrl = { + .syn_option = __smc_bpf_stub_set_tcp_option, + .synack_option = __smc_bpf_stub_set_tcp_option_cond, +}; + +static int smc_bpf_hs_ctrl_init(struct btf *btf) { return 0; } + +static int smc_bpf_hs_ctrl_reg(void *kdata, struct bpf_link *link) +{ + if (link) + return -EOPNOTSUPP; + + return smc_hs_ctrl_reg(kdata); +} + +static void smc_bpf_hs_ctrl_unreg(void *kdata, struct bpf_link *link) +{ + smc_hs_ctrl_unreg(kdata); +} + +static int smc_bpf_hs_ctrl_init_member(const struct btf_type *t, + const struct btf_member *member, + void *kdata, const void *udata) +{ + const struct smc_hs_ctrl *u_ctrl; + struct smc_hs_ctrl *k_ctrl; + u32 moff; + + u_ctrl = (const struct smc_hs_ctrl *)udata; + k_ctrl = (struct smc_hs_ctrl *)kdata; + + moff = __btf_member_bit_offset(t, member) / 8; + switch (moff) { + case offsetof(struct smc_hs_ctrl, name): + if (bpf_obj_name_cpy(k_ctrl->name, u_ctrl->name, + sizeof(u_ctrl->name)) <= 0) + return -EINVAL; + return 1; + case offsetof(struct smc_hs_ctrl, flags): + if (u_ctrl->flags & ~SMC_HS_CTRL_ALL_FLAGS) + return -EINVAL; + k_ctrl->flags = u_ctrl->flags; + return 1; + default: + break; + } + + return 0; +} + +static const struct bpf_func_proto * +bpf_smc_hs_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + return bpf_base_func_proto(func_id, prog); +} + +static const struct bpf_verifier_ops smc_bpf_verifier_ops = { + .get_func_proto = bpf_smc_hs_func_proto, + .is_valid_access = bpf_tracing_btf_ctx_access, +}; + +static struct bpf_struct_ops bpf_smc_hs_ctrl_ops = { + .name = "smc_hs_ctrl", + .init = smc_bpf_hs_ctrl_init, + .reg = smc_bpf_hs_ctrl_reg, + .unreg = smc_bpf_hs_ctrl_unreg, + .cfi_stubs = &__smc_bpf_hs_ctrl, + .verifier_ops = &smc_bpf_verifier_ops, + .init_member = smc_bpf_hs_ctrl_init_member, + .owner = THIS_MODULE, +}; + +int bpf_smc_hs_ctrl_init(void) +{ + return register_bpf_struct_ops(&bpf_smc_hs_ctrl_ops, smc_hs_ctrl); +} diff --git a/net/smc/smc_hs_bpf.h b/net/smc/smc_hs_bpf.h new file mode 100644 index 0000000000000..f5f1807c079e8 --- /dev/null +++ b/net/smc/smc_hs_bpf.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Generic hook for SMC handshake flow. + * + * Copyright IBM Corp. 2016 + * Copyright (c) 2025, Alibaba Inc. + * + * Author: D. Wythe + */ + +#ifndef __SMC_HS_CTRL +#define __SMC_HS_CTRL + +#include + +/* Find hs_ctrl by the target name, which required to be a c-string. + * Return NULL if no such ctrl was found,otherwise, return a valid ctrl. + * + * Note: Caller MUST ensure it's was invoked under rcu_read_lock. + */ +struct smc_hs_ctrl *smc_hs_ctrl_find_by_name(const char *name); + +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) +int bpf_smc_hs_ctrl_init(void); +#else +static inline int bpf_smc_hs_ctrl_init(void) { return 0; } +#endif /* CONFIG_SMC_HS_CTRL_BPF */ + +#endif /* __SMC_HS_CTRL */ diff --git a/net/smc/smc_sysctl.c b/net/smc/smc_sysctl.c index 2fab6456f7654..9185167344681 100644 --- a/net/smc/smc_sysctl.c +++ b/net/smc/smc_sysctl.c @@ -12,12 +12,14 @@ #include #include +#include #include #include "smc.h" #include "smc_core.h" #include "smc_llc.h" #include "smc_sysctl.h" +#include "smc_hs_bpf.h" static int min_sndbuf = SMC_BUF_MIN_SIZE; static int min_rcvbuf = SMC_BUF_MIN_SIZE; @@ -30,6 +32,69 @@ static int links_per_lgr_max = SMC_LINKS_ADD_LNK_MAX; static int conns_per_lgr_min = SMC_CONN_PER_LGR_MIN; static int conns_per_lgr_max = SMC_CONN_PER_LGR_MAX; +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) +static int smc_net_replace_smc_hs_ctrl(struct net *net, const char *name) +{ + struct smc_hs_ctrl *ctrl = NULL; + + rcu_read_lock(); + /* null or empty name ask to clear current ctrl */ + if (name && name[0]) { + ctrl = smc_hs_ctrl_find_by_name(name); + if (!ctrl) { + rcu_read_unlock(); + return -EINVAL; + } + /* no change, just return */ + if (ctrl == rcu_dereference(net->smc.hs_ctrl)) { + rcu_read_unlock(); + return 0; + } + if (!bpf_try_module_get(ctrl, ctrl->owner)) { + rcu_read_unlock(); + return -EBUSY; + } + } + /* xhcg old ctrl with the new one atomically */ + ctrl = unrcu_pointer(xchg(&net->smc.hs_ctrl, RCU_INITIALIZER(ctrl))); + /* release old ctrl */ + if (ctrl) + bpf_module_put(ctrl, ctrl->owner); + + rcu_read_unlock(); + return 0; +} + +static int proc_smc_hs_ctrl(const struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(ctl->data, struct net, smc.hs_ctrl); + char val[SMC_HS_CTRL_NAME_MAX]; + const struct ctl_table tbl = { + .data = val, + .maxlen = SMC_HS_CTRL_NAME_MAX, + }; + struct smc_hs_ctrl *ctrl; + int ret; + + rcu_read_lock(); + ctrl = rcu_dereference(net->smc.hs_ctrl); + if (ctrl) + memcpy(val, ctrl->name, sizeof(ctrl->name)); + else + val[0] = '\0'; + rcu_read_unlock(); + + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + if (ret) + return ret; + + if (write) + ret = smc_net_replace_smc_hs_ctrl(net, val); + return ret; +} +#endif /* CONFIG_SMC_HS_CTRL_BPF */ + static struct ctl_table smc_table[] = { { .procname = "autocorking_size", @@ -99,6 +164,15 @@ static struct ctl_table smc_table[] = { .extra1 = SYSCTL_ZERO, .extra2 = SYSCTL_ONE, }, +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) + { + .procname = "hs_ctrl", + .data = &init_net.smc.hs_ctrl, + .mode = 0644, + .maxlen = SMC_HS_CTRL_NAME_MAX, + .proc_handler = proc_smc_hs_ctrl, + }, +#endif /* CONFIG_SMC_HS_CTRL_BPF */ }; int __net_init smc_sysctl_net_init(struct net *net) @@ -109,6 +183,16 @@ int __net_init smc_sysctl_net_init(struct net *net) table = smc_table; if (!net_eq(net, &init_net)) { int i; +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) + struct smc_hs_ctrl *ctrl; + + rcu_read_lock(); + ctrl = rcu_dereference(init_net.smc.hs_ctrl); + if (ctrl && ctrl->flags & SMC_HS_CTRL_FLAG_INHERITABLE && + bpf_try_module_get(ctrl, ctrl->owner)) + rcu_assign_pointer(net->smc.hs_ctrl, ctrl); + rcu_read_unlock(); +#endif /* CONFIG_SMC_HS_CTRL_BPF */ table = kmemdup(table, sizeof(smc_table), GFP_KERNEL); if (!table) @@ -139,6 +223,9 @@ int __net_init smc_sysctl_net_init(struct net *net) if (!net_eq(net, &init_net)) kfree(table); err_alloc: +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) + smc_net_replace_smc_hs_ctrl(net, NULL); +#endif /* CONFIG_SMC_HS_CTRL_BPF */ return -ENOMEM; } @@ -148,6 +235,10 @@ void __net_exit smc_sysctl_net_exit(struct net *net) table = net->smc.smc_hdr->ctl_table_arg; unregister_net_sysctl_table(net->smc.smc_hdr); +#if IS_ENABLED(CONFIG_SMC_HS_CTRL_BPF) + smc_net_replace_smc_hs_ctrl(net, NULL); +#endif /* CONFIG_SMC_HS_CTRL_BPF */ + if (!net_eq(net, &init_net)) kfree(table); } diff --git a/tools/testing/selftests/bpf/config b/tools/testing/selftests/bpf/config index 70b28c1e653ea..fcd2f9bf78c99 100644 --- a/tools/testing/selftests/bpf/config +++ b/tools/testing/selftests/bpf/config @@ -123,3 +123,8 @@ CONFIG_XDP_SOCKETS=y CONFIG_XFRM_INTERFACE=y CONFIG_TCP_CONG_DCTCP=y CONFIG_TCP_CONG_BBR=y +CONFIG_INFINIBAND=y +CONFIG_SMC=y +CONFIG_SMC_HS_CTRL_BPF=y +CONFIG_DIBS=y +CONFIG_DIBS_LO=y \ No newline at end of file diff --git a/tools/testing/selftests/bpf/prog_tests/test_bpf_smc.c b/tools/testing/selftests/bpf/prog_tests/test_bpf_smc.c new file mode 100644 index 0000000000000..de22734abc4d2 --- /dev/null +++ b/tools/testing/selftests/bpf/prog_tests/test_bpf_smc.c @@ -0,0 +1,390 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include "network_helpers.h" +#include "bpf_smc.skel.h" + +#ifndef IPPROTO_SMC +#define IPPROTO_SMC 256 +#endif + +#define CLIENT_IP "127.0.0.1" +#define SERVER_IP "127.0.1.0" +#define SERVER_IP_VIA_RISK_PATH "127.0.2.0" + +#define SERVICE_1 80 +#define SERVICE_2 443 +#define SERVICE_3 8443 + +#define TEST_NS "bpf_smc_netns" + +static struct netns_obj *test_netns; + +struct smc_policy_ip_key { + __u32 sip; + __u32 dip; +}; + +struct smc_policy_ip_value { + __u8 mode; +}; + +#if defined(__s390x__) +/* s390x has default seid */ +static bool setup_ueid(void) { return true; } +static void cleanup_ueid(void) {} +#else +enum { + SMC_NETLINK_ADD_UEID = 10, + SMC_NETLINK_REMOVE_UEID +}; + +enum { + SMC_NLA_EID_TABLE_UNSPEC, + SMC_NLA_EID_TABLE_ENTRY, /* string */ +}; + +struct msgtemplate { + struct nlmsghdr n; + struct genlmsghdr g; + char buf[1024]; +}; + +#define GENLMSG_DATA(glh) ((void *)(NLMSG_DATA(glh) + GENL_HDRLEN)) +#define GENLMSG_PAYLOAD(glh) (NLMSG_PAYLOAD(glh, 0) - GENL_HDRLEN) +#define NLA_DATA(na) ((void *)((char *)(na) + NLA_HDRLEN)) +#define NLA_PAYLOAD(len) ((len) - NLA_HDRLEN) + +#define SMC_GENL_FAMILY_NAME "SMC_GEN_NETLINK" +#define SMC_BPFTEST_UEID "SMC-BPFTEST-UEID" + +static uint16_t smc_nl_family_id = -1; + +static int send_cmd(int fd, __u16 nlmsg_type, __u32 nlmsg_pid, + __u16 nlmsg_flags, __u8 genl_cmd, __u16 nla_type, + void *nla_data, int nla_len) +{ + struct nlattr *na; + struct sockaddr_nl nladdr; + int r, buflen; + char *buf; + + struct msgtemplate msg = {0}; + + msg.n.nlmsg_len = NLMSG_LENGTH(GENL_HDRLEN); + msg.n.nlmsg_type = nlmsg_type; + msg.n.nlmsg_flags = nlmsg_flags; + msg.n.nlmsg_seq = 0; + msg.n.nlmsg_pid = nlmsg_pid; + msg.g.cmd = genl_cmd; + msg.g.version = 1; + na = (struct nlattr *)GENLMSG_DATA(&msg); + na->nla_type = nla_type; + na->nla_len = nla_len + 1 + NLA_HDRLEN; + memcpy(NLA_DATA(na), nla_data, nla_len); + msg.n.nlmsg_len += NLMSG_ALIGN(na->nla_len); + + buf = (char *)&msg; + buflen = msg.n.nlmsg_len; + memset(&nladdr, 0, sizeof(nladdr)); + nladdr.nl_family = AF_NETLINK; + + while ((r = sendto(fd, buf, buflen, 0, (struct sockaddr *)&nladdr, + sizeof(nladdr))) < buflen) { + if (r > 0) { + buf += r; + buflen -= r; + } else if (errno != EAGAIN) { + return -1; + } + } + return 0; +} + +static bool get_smc_nl_family_id(void) +{ + struct sockaddr_nl nl_src; + struct msgtemplate msg; + struct nlattr *nl; + int fd, ret; + pid_t pid; + + fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC); + if (!ASSERT_OK_FD(fd, "nl_family socket")) + return false; + + pid = getpid(); + + memset(&nl_src, 0, sizeof(nl_src)); + nl_src.nl_family = AF_NETLINK; + nl_src.nl_pid = pid; + + ret = bind(fd, (struct sockaddr *)&nl_src, sizeof(nl_src)); + if (!ASSERT_OK(ret, "nl_family bind")) + goto fail; + + ret = send_cmd(fd, GENL_ID_CTRL, pid, + NLM_F_REQUEST, CTRL_CMD_GETFAMILY, + CTRL_ATTR_FAMILY_NAME, (void *)SMC_GENL_FAMILY_NAME, + strlen(SMC_GENL_FAMILY_NAME)); + if (!ASSERT_OK(ret, "nl_family query")) + goto fail; + + ret = recv(fd, &msg, sizeof(msg), 0); + if (!ASSERT_FALSE(msg.n.nlmsg_type == NLMSG_ERROR || ret < 0 || + !NLMSG_OK(&msg.n, ret), "nl_family response")) + goto fail; + + nl = (struct nlattr *)GENLMSG_DATA(&msg); + nl = (struct nlattr *)((char *)nl + NLA_ALIGN(nl->nla_len)); + if (!ASSERT_EQ(nl->nla_type, CTRL_ATTR_FAMILY_ID, "nl_family nla type")) + goto fail; + + smc_nl_family_id = *(uint16_t *)NLA_DATA(nl); + close(fd); + return true; +fail: + close(fd); + return false; +} + +static bool smc_ueid(int op) +{ + struct sockaddr_nl nl_src; + struct msgtemplate msg; + struct nlmsgerr *err; + char test_ueid[32]; + int fd, ret; + pid_t pid; + + /* UEID required */ + memset(test_ueid, '\x20', sizeof(test_ueid)); + memcpy(test_ueid, SMC_BPFTEST_UEID, strlen(SMC_BPFTEST_UEID)); + fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC); + if (!ASSERT_OK_FD(fd, "ueid socket")) + return false; + + pid = getpid(); + memset(&nl_src, 0, sizeof(nl_src)); + nl_src.nl_family = AF_NETLINK; + nl_src.nl_pid = pid; + + ret = bind(fd, (struct sockaddr *)&nl_src, sizeof(nl_src)); + if (!ASSERT_OK(ret, "ueid bind")) + goto fail; + + ret = send_cmd(fd, smc_nl_family_id, pid, + NLM_F_REQUEST | NLM_F_ACK, op, SMC_NLA_EID_TABLE_ENTRY, + (void *)test_ueid, sizeof(test_ueid)); + if (!ASSERT_OK(ret, "ueid cmd")) + goto fail; + + ret = recv(fd, &msg, sizeof(msg), 0); + if (!ASSERT_FALSE(ret < 0 || + !NLMSG_OK(&msg.n, ret), "ueid response")) + goto fail; + + if (msg.n.nlmsg_type == NLMSG_ERROR) { + err = NLMSG_DATA(&msg); + switch (op) { + case SMC_NETLINK_REMOVE_UEID: + if (!ASSERT_FALSE((err->error && err->error != -ENOENT), + "ueid remove")) + goto fail; + break; + case SMC_NETLINK_ADD_UEID: + if (!ASSERT_OK(err->error, "ueid add")) + goto fail; + break; + default: + break; + } + } + close(fd); + return true; +fail: + close(fd); + return false; +} + +static bool setup_ueid(void) +{ + /* get smc nl id */ + if (!get_smc_nl_family_id()) + return false; + /* clear old ueid for bpftest */ + smc_ueid(SMC_NETLINK_REMOVE_UEID); + /* smc-loopback required ueid */ + return smc_ueid(SMC_NETLINK_ADD_UEID); +} + +static void cleanup_ueid(void) +{ + smc_ueid(SMC_NETLINK_REMOVE_UEID); +} +#endif /* __s390x__ */ + +static bool setup_netns(void) +{ + test_netns = netns_new(TEST_NS, true); + if (!ASSERT_OK_PTR(test_netns, "open net namespace")) + goto fail_netns; + + SYS(fail_ip, "ip addr add 127.0.1.0/8 dev lo"); + SYS(fail_ip, "ip addr add 127.0.2.0/8 dev lo"); + + return true; +fail_ip: + netns_free(test_netns); +fail_netns: + return false; +} + +static void cleanup_netns(void) +{ + netns_free(test_netns); +} + +static bool setup_smc(void) +{ + if (!setup_ueid()) + return false; + + if (!setup_netns()) + goto fail_netns; + + return true; +fail_netns: + cleanup_ueid(); + return false; +} + +static int set_client_addr_cb(int fd, void *opts) +{ + const char *src = (const char *)opts; + struct sockaddr_in localaddr; + + localaddr.sin_family = AF_INET; + localaddr.sin_port = htons(0); + localaddr.sin_addr.s_addr = inet_addr(src); + return !ASSERT_OK(bind(fd, &localaddr, sizeof(localaddr)), "client bind"); +} + +static void run_link(const char *src, const char *dst, int port) +{ + struct network_helper_opts opts = {0}; + int server, client; + + server = start_server_str(AF_INET, SOCK_STREAM, dst, port, NULL); + if (!ASSERT_OK_FD(server, "start service_1")) + return; + + opts.proto = IPPROTO_TCP; + opts.post_socket_cb = set_client_addr_cb; + opts.cb_opts = (void *)src; + + client = connect_to_fd_opts(server, &opts); + if (!ASSERT_OK_FD(client, "start connect")) + goto fail_client; + + close(client); +fail_client: + close(server); +} + +static void block_link(int map_fd, const char *src, const char *dst) +{ + struct smc_policy_ip_value val = { .mode = /* block */ 0 }; + struct smc_policy_ip_key key = { + .sip = inet_addr(src), + .dip = inet_addr(dst), + }; + + bpf_map_update_elem(map_fd, &key, &val, BPF_ANY); +} + +/* + * This test describes a real-life service topology as follows: + * + * +-------------> service_1 + * link 1 | | + * +--------------------> server | link 2 + * | | V + * | +-------------> service_2 + * | link 3 + * client -------------------> server_via_unsafe_path -> service_3 + * + * Among them, + * 1. link-1 is very suitable for using SMC. + * 2. link-2 is not suitable for using SMC, because the mode of this link is + * kind of short-link services. + * 3. link-3 is also not suitable for using SMC, because the RDMA link is + * unavailable and needs to go through a long timeout before it can fallback + * to TCP. + * To achieve this goal, we use a customized SMC ip strategy via smc_hs_ctrl. + */ +static void test_topo(void) +{ + struct bpf_smc *skel; + int rc, map_fd; + + skel = bpf_smc__open_and_load(); + if (!ASSERT_OK_PTR(skel, "bpf_smc__open_and_load")) + return; + + rc = bpf_smc__attach(skel); + if (!ASSERT_OK(rc, "bpf_smc__attach")) + goto fail; + + map_fd = bpf_map__fd(skel->maps.smc_policy_ip); + if (!ASSERT_OK_FD(map_fd, "bpf_map__fd")) + goto fail; + + /* Mock the process of transparent replacement, since we will modify + * protocol to ipproto_smc accropding to it via + * fmod_ret/update_socket_protocol. + */ + write_sysctl("/proc/sys/net/smc/hs_ctrl", "linkcheck"); + + /* Configure ip strat */ + block_link(map_fd, CLIENT_IP, SERVER_IP_VIA_RISK_PATH); + block_link(map_fd, SERVER_IP, SERVER_IP); + + /* should go with smc */ + run_link(CLIENT_IP, SERVER_IP, SERVICE_1); + /* should go with smc fallback */ + run_link(SERVER_IP, SERVER_IP, SERVICE_2); + + ASSERT_EQ(skel->bss->smc_cnt, 2, "smc count"); + ASSERT_EQ(skel->bss->fallback_cnt, 1, "fallback count"); + + /* should go with smc */ + run_link(CLIENT_IP, SERVER_IP, SERVICE_2); + + ASSERT_EQ(skel->bss->smc_cnt, 3, "smc count"); + ASSERT_EQ(skel->bss->fallback_cnt, 1, "fallback count"); + + /* should go with smc fallback */ + run_link(CLIENT_IP, SERVER_IP_VIA_RISK_PATH, SERVICE_3); + + ASSERT_EQ(skel->bss->smc_cnt, 4, "smc count"); + ASSERT_EQ(skel->bss->fallback_cnt, 2, "fallback count"); + +fail: + bpf_smc__destroy(skel); +} + +void test_bpf_smc(void) +{ + if (!setup_smc()) { + printf("setup for smc test failed, test SKIP:\n"); + test__skip(); + return; + } + + if (test__start_subtest("topo")) + test_topo(); + + cleanup_ueid(); + cleanup_netns(); +} diff --git a/tools/testing/selftests/bpf/progs/bpf_smc.c b/tools/testing/selftests/bpf/progs/bpf_smc.c new file mode 100644 index 0000000000000..892094a09fade --- /dev/null +++ b/tools/testing/selftests/bpf/progs/bpf_smc.c @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include "vmlinux.h" + +#include +#include +#include "bpf_tracing_net.h" + +char _license[] SEC("license") = "GPL"; + +enum { + BPF_SMC_LISTEN = 10, +}; + +struct smc_sock___local { + struct sock sk; + struct smc_sock *listen_smc; + bool use_fallback; +} __attribute__((preserve_access_index)); + +int smc_cnt = 0; +int fallback_cnt = 0; + +SEC("fentry/smc_release") +int BPF_PROG(bpf_smc_release, struct socket *sock) +{ + /* only count from one side (client) */ + if (sock->sk->__sk_common.skc_state == BPF_SMC_LISTEN) + return 0; + smc_cnt++; + return 0; +} + +SEC("fentry/smc_switch_to_fallback") +int BPF_PROG(bpf_smc_switch_to_fallback, struct smc_sock___local *smc) +{ + /* only count from one side (client) */ + if (smc && !smc->listen_smc) + fallback_cnt++; + return 0; +} + +/* go with default value if no strat was found */ +bool default_ip_strat_value = true; + +struct smc_policy_ip_key { + __u32 sip; + __u32 dip; +}; + +struct smc_policy_ip_value { + __u8 mode; +}; + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(key_size, sizeof(struct smc_policy_ip_key)); + __uint(value_size, sizeof(struct smc_policy_ip_value)); + __uint(max_entries, 128); + __uint(map_flags, BPF_F_NO_PREALLOC); +} smc_policy_ip SEC(".maps"); + +static bool smc_check(__u32 src, __u32 dst) +{ + struct smc_policy_ip_value *value; + struct smc_policy_ip_key key = { + .sip = src, + .dip = dst, + }; + + value = bpf_map_lookup_elem(&smc_policy_ip, &key); + return value ? value->mode : default_ip_strat_value; +} + +SEC("fmod_ret/update_socket_protocol") +int BPF_PROG(smc_run, int family, int type, int protocol) +{ + struct task_struct *task; + + if (family != AF_INET && family != AF_INET6) + return protocol; + + if ((type & 0xf) != SOCK_STREAM) + return protocol; + + if (protocol != 0 && protocol != IPPROTO_TCP) + return protocol; + + task = bpf_get_current_task_btf(); + /* Prevent from affecting other tests */ + if (!task || !task->nsproxy->net_ns->smc.hs_ctrl) + return protocol; + + return IPPROTO_SMC; +} + +SEC("struct_ops/bpf_smc_set_tcp_option_cond") +int BPF_PROG(bpf_smc_set_tcp_option_cond, const struct tcp_sock *tp, + struct inet_request_sock *ireq) +{ + return smc_check(ireq->req.__req_common.skc_daddr, + ireq->req.__req_common.skc_rcv_saddr); +} + +SEC("struct_ops/bpf_smc_set_tcp_option") +int BPF_PROG(bpf_smc_set_tcp_option, struct tcp_sock *tp) +{ + return smc_check(tp->inet_conn.icsk_inet.sk.__sk_common.skc_rcv_saddr, + tp->inet_conn.icsk_inet.sk.__sk_common.skc_daddr); +} + +SEC(".struct_ops") +struct smc_hs_ctrl linkcheck = { + .name = "linkcheck", + .syn_option = (void *)bpf_smc_set_tcp_option, + .synack_option = (void *)bpf_smc_set_tcp_option_cond, +};