From 994c1fd7d0bfa17ac6ef6a0fc83457d159aaf1ad Mon Sep 17 00:00:00 2001 From: Orion Poplawski Date: Thu, 20 Nov 2025 15:10:34 -0700 Subject: [PATCH] [style] Standardize on double-quotes --- examples/ipsec/filter_plugins/pfsense.py | 144 +- plugins/lookup/pfsense.py | 2008 +++++++++++------ plugins/module_utils/__impl/addresses.py | 116 +- plugins/module_utils/__impl/checks.py | 64 +- plugins/module_utils/__impl/interfaces.py | 97 +- plugins/module_utils/alias.py | 112 +- plugins/module_utils/arg_route.py | 8 +- plugins/module_utils/default_gateway.py | 75 +- plugins/module_utils/dhcp_server.py | 440 ++-- plugins/module_utils/gateway.py | 355 ++- plugins/module_utils/haproxy_backend.py | 269 ++- .../module_utils/haproxy_backend_server.py | 381 ++-- plugins/module_utils/interface.py | 626 +++-- plugins/module_utils/interface_group.py | 158 +- plugins/module_utils/ipsec.py | 615 +++-- plugins/module_utils/ipsec_p2.py | 676 ++++-- plugins/module_utils/ipsec_proposal.py | 192 +- plugins/module_utils/module_base.py | 403 ++-- plugins/module_utils/module_config_base.py | 62 +- plugins/module_utils/nat_outbound.py | 586 +++-- plugins/module_utils/nat_port_forward.py | 528 +++-- plugins/module_utils/openvpn_client.py | 386 ++-- plugins/module_utils/openvpn_override.py | 266 ++- plugins/module_utils/openvpn_server.py | 592 +++-- plugins/module_utils/pfsense.py | 510 +++-- plugins/module_utils/route.py | 152 +- plugins/module_utils/rule.py | 975 +++++--- plugins/module_utils/rule_separator.py | 138 +- plugins/module_utils/vlan.py | 208 +- plugins/modules/pfsense_aggregate.py | 484 ++-- plugins/modules/pfsense_alias.py | 21 +- plugins/modules/pfsense_authserver_ldap.py | 222 +- plugins/modules/pfsense_authserver_radius.py | 117 +- plugins/modules/pfsense_ca.py | 477 ++-- plugins/modules/pfsense_cert.py | 333 ++- plugins/modules/pfsense_default_gateway.py | 21 +- plugins/modules/pfsense_dhcp_server.py | 20 +- plugins/modules/pfsense_dhcp_static.py | 292 ++- plugins/modules/pfsense_dns_resolver.py | 449 +++- plugins/modules/pfsense_gateway.py | 20 +- plugins/modules/pfsense_group.py | 126 +- plugins/modules/pfsense_haproxy_backend.py | 20 +- .../modules/pfsense_haproxy_backend_server.py | 14 +- plugins/modules/pfsense_interface.py | 16 +- plugins/modules/pfsense_interface_group.py | 16 +- plugins/modules/pfsense_ipsec.py | 20 +- plugins/modules/pfsense_ipsec_aggregate.py | 243 +- plugins/modules/pfsense_ipsec_p2.py | 20 +- plugins/modules/pfsense_ipsec_proposal.py | 16 +- plugins/modules/pfsense_log_settings.py | 321 ++- plugins/modules/pfsense_nat_outbound.py | 14 +- plugins/modules/pfsense_nat_port_forward.py | 16 +- plugins/modules/pfsense_openvpn_client.py | 20 +- plugins/modules/pfsense_openvpn_override.py | 16 +- plugins/modules/pfsense_openvpn_server.py | 20 +- plugins/modules/pfsense_phpshell.py | 47 +- plugins/modules/pfsense_rewrite_config.py | 41 +- plugins/modules/pfsense_route.py | 20 +- plugins/modules/pfsense_rule.py | 20 +- plugins/modules/pfsense_rule_separator.py | 30 +- plugins/modules/pfsense_setup.py | 696 ++++-- plugins/modules/pfsense_shellcmd.py | 81 +- plugins/modules/pfsense_user.py | 191 +- plugins/modules/pfsense_vlan.py | 20 +- pyproject.toml | 4 + tests/unit/plugins/lookup/test_pfsense.py | 254 ++- .../unit/plugins/module_utils/test_pfsense.py | 29 +- tests/unit/plugins/modules/pfsense_module.py | 322 ++- .../plugins/modules/test_pfsense_aggregate.py | 453 ++-- .../plugins/modules/test_pfsense_alias.py | 460 ++-- .../modules/test_pfsense_alias_null.py | 38 +- .../modules/test_pfsense_authserver_ldap.py | 202 +- .../modules/test_pfsense_authserver_radius.py | 115 +- tests/unit/plugins/modules/test_pfsense_ca.py | 101 +- .../unit/plugins/modules/test_pfsense_cert.py | 73 +- .../modules/test_pfsense_dhcp_server.py | 198 +- .../modules/test_pfsense_dhcp_static.py | 187 +- .../modules/test_pfsense_dns_resolver.py | 145 +- .../plugins/modules/test_pfsense_gateway.py | 225 +- .../modules/test_pfsense_haproxy_backend.py | 105 +- .../test_pfsense_haproxy_backend_server.py | 237 +- .../plugins/modules/test_pfsense_interface.py | 363 ++- .../modules/test_pfsense_interface_group.py | 96 +- .../plugins/modules/test_pfsense_ipsec.py | 340 ++- .../modules/test_pfsense_ipsec_aggregate.py | 287 ++- .../plugins/modules/test_pfsense_ipsec_p2.py | 542 +++-- .../modules/test_pfsense_ipsec_proposal.py | 130 +- .../modules/test_pfsense_log_settings.py | 620 ++--- .../modules/test_pfsense_nat_outbound.py | 369 ++- .../modules/test_pfsense_nat_port_forward.py | 293 ++- .../modules/test_pfsense_openvpn_override.py | 99 +- .../modules/test_pfsense_openvpn_server.py | 264 ++- .../plugins/modules/test_pfsense_route.py | 90 +- .../unit/plugins/modules/test_pfsense_rule.py | 230 +- .../modules/test_pfsense_rule_create.py | 1009 +++++++-- .../plugins/modules/test_pfsense_rule_misc.py | 22 +- .../plugins/modules/test_pfsense_rule_noop.py | 315 ++- .../modules/test_pfsense_rule_separator.py | 111 +- .../modules/test_pfsense_rule_update.py | 632 +++++- .../plugins/modules/test_pfsense_setup.py | 279 ++- .../unit/plugins/modules/test_pfsense_user.py | 80 +- .../unit/plugins/modules/test_pfsense_vlan.py | 103 +- 102 files changed, 16308 insertions(+), 8756 deletions(-) create mode 100644 pyproject.toml diff --git a/examples/ipsec/filter_plugins/pfsense.py b/examples/ipsec/filter_plugins/pfsense.py index 4106e6de..69ac13d8 100644 --- a/examples/ipsec/filter_plugins/pfsense.py +++ b/examples/ipsec/filter_plugins/pfsense.py @@ -3,17 +3,18 @@ # Copyright: (c) 2019, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type from ansible.errors import AnsibleFilterError def format_ipsec_aggregate_ipsecs(all_tunnels, pfname): - """ format ipsecs for format_ipsec_aggregate """ + """format ipsecs for format_ipsec_aggregate""" res = list() for name, ipsec in all_tunnels.items(): - pfsenses = ipsec['pfsenses'] + pfsenses = ipsec["pfsenses"] if pfname not in pfsenses: continue local = pfsenses[pfname] @@ -25,49 +26,51 @@ def format_ipsec_aggregate_ipsecs(all_tunnels, pfname): params = dict() res.append(params) - params['descr'] = name + ' to ' + remote_name - params['state'] = 'present' + params["descr"] = name + " to " + remote_name + params["state"] = "present" for option in ipsec: - if option in ['pfsenses', 'phase1', 'phase2']: + if option in ["pfsenses", "phase1", "phase2"]: continue params[option] = ipsec[option] for option in remote_options: - if option in ['sharing', 'myid_data']: + if option in ["sharing", "myid_data"]: continue params[option] = remote_options[option] - if 'peerid_type' in params and params['peerid_type'] == 'keyid tag': - params['peerid_data'] = remote_options['myid_data'] + if "peerid_type" in params and params["peerid_type"] == "keyid tag": + params["peerid_data"] = remote_options["myid_data"] - if 'myid_data' in local: - params['myid_data'] = local['myid_data'] + if "myid_data" in local: + params["myid_data"] = local["myid_data"] return res def format_ipsec_aggregate_proposals(all_tunnels, pfname): - """ format proposals for format_ipsec_aggregate """ + """format proposals for format_ipsec_aggregate""" res = list() for name, ipsec in all_tunnels.items(): - pfsenses = ipsec['pfsenses'] + pfsenses = ipsec["pfsenses"] if pfname not in pfsenses: continue - if 'phase1' not in ipsec: + if "phase1" not in ipsec: raise AnsibleFilterError("phase1 is missing in {0}".format(name)) - phase1 = ipsec['phase1'] + phase1 = ipsec["phase1"] p1s = list() - if 'encryptions' not in phase1: - raise AnsibleFilterError("encryptions is missing in phase1 of {0}".format(name)) + if "encryptions" not in phase1: + raise AnsibleFilterError( + "encryptions is missing in phase1 of {0}".format(name) + ) - if 'hashes' not in phase1: + if "hashes" not in phase1: raise AnsibleFilterError("hashes is missing in phase1 of {0}".format(name)) - encryptions = phase1['encryptions'] - hashes = phase1['hashes'].split(' ') + encryptions = phase1["encryptions"] + hashes = phase1["hashes"].split(" ") for remote_name in pfsenses: if remote_name == pfname: @@ -77,14 +80,17 @@ def format_ipsec_aggregate_proposals(all_tunnels, pfname): for hash_option in hashes: params = dict() p1s.append(params) - params['descr'] = name + ' to ' + remote_name - params['state'] = 'present' - params['hash'] = hash_option - params['encryption'] = encryption - if encryptions[encryption] is not None and encryptions[encryption] != 'None': - params['key_length'] = encryptions[encryption] + params["descr"] = name + " to " + remote_name + params["state"] = "present" + params["hash"] = hash_option + params["encryption"] = encryption + if ( + encryptions[encryption] is not None + and encryptions[encryption] != "None" + ): + params["key_length"] = encryptions[encryption] for p1_option in phase1: - if p1_option in ['encryptions', 'hashes']: + if p1_option in ["encryptions", "hashes"]: continue for p1 in p1s: p1[p1_option] = phase1[p1_option] @@ -93,67 +99,71 @@ def format_ipsec_aggregate_proposals(all_tunnels, pfname): def format_ipsec_aggregate_p2s(all_tunnels, pfname): - """ format p2s for format_ipsec_aggregate """ + """format p2s for format_ipsec_aggregate""" res = list() for name, ipsec in all_tunnels.items(): - pfsenses = ipsec['pfsenses'] + pfsenses = ipsec["pfsenses"] if pfname not in pfsenses: continue - if 'phase2' not in ipsec: + if "phase2" not in ipsec: raise AnsibleFilterError("phase2 is missing in {0}".format(name)) - phase2 = ipsec['phase2'] + phase2 = ipsec["phase2"] - if 'mode' not in phase2: + if "mode" not in phase2: raise AnsibleFilterError("mode is missing in phase2 of {0}".format(name)) - mode = phase2['mode'] + mode = phase2["mode"] local = pfsenses[pfname] - if 'sharing' in local: - local_sharing = local['sharing'].split(' ') - elif mode != 'transport': - raise AnsibleFilterError("sharing is missing for {0} in {1}".format(pfname, name)) + if "sharing" in local: + local_sharing = local["sharing"].split(" ") + elif mode != "transport": + raise AnsibleFilterError( + "sharing is missing for {0} in {1}".format(pfname, name) + ) p2s = list() for remote_name, remote in pfsenses.items(): if remote_name == pfname: continue - if 'sharing' in remote: - remote_sharing = remote['sharing'].split(' ') - elif mode != 'transport': - raise AnsibleFilterError("sharing is missing for {0} in {1}".format(remote_name, name)) - - if mode != 'transport': + if "sharing" in remote: + remote_sharing = remote["sharing"].split(" ") + elif mode != "transport": + raise AnsibleFilterError( + "sharing is missing for {0} in {1}".format(remote_name, name) + ) + + if mode != "transport": for local_network in local_sharing: for remote_network in remote_sharing: params = dict() p2s.append(params) - params['p1_descr'] = name + ' to ' + remote_name - params['descr'] = local_network + ' to ' + remote_network - params['state'] = 'present' - params['local'] = local_network - params['remote'] = remote_network + params["p1_descr"] = name + " to " + remote_name + params["descr"] = local_network + " to " + remote_network + params["state"] = "present" + params["local"] = local_network + params["remote"] = remote_network else: params = dict() p2s.append(params) - params['descr'] = name + ' to ' + remote_name - params['p1_descr'] = name + ' to ' + remote_name - params['state'] = 'present' + params["descr"] = name + " to " + remote_name + params["p1_descr"] = name + " to " + remote_name + params["state"] = "present" for p2_option, p2_value in phase2.items(): for p2 in p2s: - if p2_option == 'encryptions': + if p2_option == "encryptions": for encryption, keylength in p2_value.items(): p2[encryption] = True - if keylength is not None and keylength != 'None': + if keylength is not None and keylength != "None": if isinstance(keylength, str): - p2[encryption + '_len'] = keylength + p2[encryption + "_len"] = keylength else: - p2[encryption + '_len'] = str(keylength) - elif p2_option == 'hashes': - hashes = p2_value.split(' ') + p2[encryption + "_len"] = str(keylength) + elif p2_option == "hashes": + hashes = p2_value.split(" ") for hash_option in hashes: p2[hash_option] = True else: @@ -163,27 +173,31 @@ def format_ipsec_aggregate_p2s(all_tunnels, pfname): def format_ipsec_aggregate(*terms): - """ format var for ipsec_aggregate """ + """format var for ipsec_aggregate""" if len(terms) != 2 or not isinstance(terms[0], dict): - raise AnsibleFilterError("format_ipsec_aggregate expects one dictionnary of ipsec tunnels") + raise AnsibleFilterError( + "format_ipsec_aggregate expects one dictionnary of ipsec tunnels" + ) all_tunnels = terms[0] pfname = terms[1] res = dict() - res['aggregated_ipsecs'] = format_ipsec_aggregate_ipsecs(all_tunnels, pfname) - res['aggregated_ipsec_proposals'] = format_ipsec_aggregate_proposals(all_tunnels, pfname) - res['aggregated_ipsec_p2s'] = format_ipsec_aggregate_p2s(all_tunnels, pfname) + res["aggregated_ipsecs"] = format_ipsec_aggregate_ipsecs(all_tunnels, pfname) + res["aggregated_ipsec_proposals"] = format_ipsec_aggregate_proposals( + all_tunnels, pfname + ) + res["aggregated_ipsec_p2s"] = format_ipsec_aggregate_p2s(all_tunnels, pfname) return res class FilterModule(object): - """ FilterModule """ + """FilterModule""" @staticmethod def filters(): - """ defined functions """ + """defined functions""" return { - 'format_ipsec_aggregate': format_ipsec_aggregate, + "format_ipsec_aggregate": format_ipsec_aggregate, } diff --git a/plugins/lookup/pfsense.py b/plugins/lookup/pfsense.py index bf83f336..924fd42c 100644 --- a/plugins/lookup/pfsense.py +++ b/plugins/lookup/pfsense.py @@ -3,7 +3,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type DOCUMENTATION = """ @@ -217,26 +218,54 @@ import ipaddress OPTION_FIELDS = [ - 'gateway', 'log', 'queue', 'ackqueue', 'in_queue', 'out_queue', 'queue_error', 'icmptype', 'filter', 'efilter', 'ifilter', 'sched', 'quick', 'direction', - 'staticnatport', 'ipprotocol', - 'associated_rule', 'natreflection', + "gateway", + "log", + "queue", + "ackqueue", + "in_queue", + "out_queue", + "queue_error", + "icmptype", + "filter", + "efilter", + "ifilter", + "sched", + "quick", + "direction", + "staticnatport", + "ipprotocol", + "associated_rule", + "natreflection", +] +OUTPUT_OPTION_FIELDS = [ + "gateway", + "log", + "queue", + "ackqueue", + "in_queue", + "out_queue", + "queue_error", + "icmptype", + "sched", + "quick", + "direction", + "ipprotocol", ] -OUTPUT_OPTION_FIELDS = ['gateway', 'log', 'queue', 'ackqueue', 'in_queue', 'out_queue', 'queue_error', 'icmptype', 'sched', 'quick', 'direction', 'ipprotocol'] -OUTPUT_SRC_NAT_OPTION_FIELDS = ['staticnatport', 'ipprotocol'] -OUTPUT_DST_NAT_OPTION_FIELDS = ['associated_rule', 'natreflection'] +OUTPUT_SRC_NAT_OPTION_FIELDS = ["staticnatport", "ipprotocol"] +OUTPUT_DST_NAT_OPTION_FIELDS = ["associated_rule", "natreflection"] display = Display() def to_unicode(string): - """ return a unicode representation of string if required """ + """return a unicode representation of string if required""" if sys.version_info[0] >= 3: return string return string.decode("utf-8") def ordered_load(stream, loader_cls=yaml.Loader, object_pairs_hook=OrderedDict): - """ load and return yaml data from stream using ordered dicts """ + """load and return yaml data from stream using ordered dicts""" class OrderedLoader(loader_cls): def __init__(self, stream): @@ -245,7 +274,7 @@ def __init__(self, stream): def include(self, node): filename = os.path.join(self._root, self.construct_scalar(node)) - with open(filename, 'r') as f: + with open(filename, "r") as f: return yaml.load(f, OrderedLoader) def construct_mapping(loader, node): @@ -253,31 +282,32 @@ def construct_mapping(loader, node): return object_pairs_hook(loader.construct_pairs(node)) if DNS_IMPORT_ERROR: - raise AnsibleError('dns must be installed to use ordered_load from this plugin') from DNS_IMPORT_ERROR + raise AnsibleError( + "dns must be installed to use ordered_load from this plugin" + ) from DNS_IMPORT_ERROR OrderedLoader.add_constructor( - yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, - construct_mapping) - OrderedLoader.add_constructor( - '!include', - OrderedLoader.include) + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping + ) + OrderedLoader.add_constructor("!include", OrderedLoader.include) return yaml.load(stream, OrderedLoader) def static_vars(**kwargs): - """ static decorator to declare static vars """ + """static decorator to declare static vars""" def decorate(func): - """ static decorator func """ + """static decorator func""" for k in kwargs: setattr(func, k, kwargs[k]) return func + return decorate @static_vars(res_cache=dict()) def to_ip_address(address): - """ convert address to IPv4Address or IPv6Address """ + """convert address to IPv4Address or IPv6Address""" res = to_ip_address.res_cache.get(address) if res is None: res = ipaddress.ip_address(to_unicode(address)) @@ -287,7 +317,7 @@ def to_ip_address(address): @static_vars(res_cache=dict()) def to_ip_network(address, strict=True): - """ convert address to IPv4Network or IPv6Network """ + """convert address to IPv4Network or IPv6Network""" key = address + str(strict) res = to_ip_network.res_cache.get(key) if res is None: @@ -297,68 +327,84 @@ def to_ip_network(address, strict=True): @static_vars( - classA=ipaddress.IPv4Network((u"10.0.0.0", u"255.0.0.0")), - classB=ipaddress.IPv4Network((u"172.16.0.0", u"255.240.0.0")), - classC=ipaddress.IPv4Network((u"192.168.0.0", u"255.255.0.0")), - res_cache=dict()) + classA=ipaddress.IPv4Network(("10.0.0.0", "255.0.0.0")), + classB=ipaddress.IPv4Network(("172.16.0.0", "255.240.0.0")), + classC=ipaddress.IPv4Network(("192.168.0.0", "255.255.0.0")), + res_cache=dict(), +) def is_private_ip(address): - """ check if ip address is class A, B or C """ + """check if ip address is class A, B or C""" res = is_private_ip.res_cache.get(address) if res is None: if not isinstance(address, ipaddress.IPv4Address): ip_address = to_ip_address(to_unicode(address)) else: ip_address = address - res = ip_address in is_private_ip.classA or ip_address in is_private_ip.classB or ip_address in is_private_ip.classC + res = ( + ip_address in is_private_ip.classA + or ip_address in is_private_ip.classB + or ip_address in is_private_ip.classC + ) is_private_ip.res_cache[address] = res return res @static_vars( - classA=ipaddress.IPv4Network((u"10.0.0.0", u"255.0.0.0")), - classB=ipaddress.IPv4Network((u"172.16.0.0", u"255.240.0.0")), - classC=ipaddress.IPv4Network((u"192.168.0.0", u"255.255.0.0")), - res_cache=dict()) + classA=ipaddress.IPv4Network(("10.0.0.0", "255.0.0.0")), + classB=ipaddress.IPv4Network(("172.16.0.0", "255.240.0.0")), + classC=ipaddress.IPv4Network(("192.168.0.0", "255.255.0.0")), + res_cache=dict(), +) def is_private_network(address): - """ check if network is class A, B or C """ + """check if network is class A, B or C""" res = is_private_network.res_cache.get(address) if res is None: if not isinstance(address, ipaddress.IPv4Network): net = to_ip_network(to_unicode(address)) else: net = address - res = net.subnet_of(is_private_network.classA) or net.subnet_of(is_private_network.classB) or net.subnet_of(is_private_network.classC) + res = ( + net.subnet_of(is_private_network.classA) + or net.subnet_of(is_private_network.classB) + or net.subnet_of(is_private_network.classC) + ) is_private_network.res_cache[address] = res return res @static_vars( - ip_broadcast=ipaddress.IPv4Address((u"255.255.255.255")), - net_multicast=ipaddress.IPv4Network((u"224.0.0.0", u"255.255.255.0")), - res_cache=dict()) + ip_broadcast=ipaddress.IPv4Address(("255.255.255.255")), + net_multicast=ipaddress.IPv4Network(("224.0.0.0", "255.255.255.0")), + res_cache=dict(), +) def is_ip_broadcast(address): - """ check if ip address is ip broadcast address """ + """check if ip address is ip broadcast address""" res = is_ip_broadcast.res_cache.get(address) if res is None: if not isinstance(address, ipaddress.IPv4Address): ip_address = to_ip_address(to_unicode(address)) else: ip_address = address - res = ip_address == is_ip_broadcast.ip_broadcast or ip_address in is_ip_broadcast.net_multicast + res = ( + ip_address == is_ip_broadcast.ip_broadcast + or ip_address in is_ip_broadcast.net_multicast + ) is_ip_broadcast.res_cache[address] = res return res @static_vars(re_cache=None) def is_fqdn(address): - """ check if address is a fqdn address """ + """check if address is a fqdn address""" if is_fqdn.re_cache is None: - is_fqdn.re_cache = re.compile(r'(?=^.{4,253}$)(^((?!-)[a-zA-Z0-9-]{0,62}[a-zA-Z0-9]\.)+[a-zA-Z]{2,63}$)') + is_fqdn.re_cache = re.compile( + r"(?=^.{4,253}$)(^((?!-)[a-zA-Z0-9-]{0,62}[a-zA-Z0-9]\.)+[a-zA-Z]{2,63}$)" + ) return is_fqdn.re_cache.match(address) is not None def resolve_hostname(address, dns_servers=None): - """ get ip for hostname """ + """get ip for hostname""" if dns_servers is None: try: resolved_ip = socket.gethostbyname(address) @@ -368,7 +414,9 @@ def resolve_hostname(address, dns_servers=None): msg = "Unable to resolve: {0}".format(address) else: if DNS_IMPORT_ERROR: - raise AnsibleError('dns must be installed to use ordered_load from this plugin') from DNS_IMPORT_ERROR + raise AnsibleError( + "dns must be installed to use ordered_load from this plugin" + ) from DNS_IMPORT_ERROR error = None try: @@ -380,17 +428,19 @@ def resolve_hostname(address, dns_servers=None): if answers: return answers[0].address except exception.Timeout: - error = 'timeout' + error = "timeout" - msg = "Unable to resolve: {0} using {1} dns servers".format(address, ','.join(dns_servers)) + msg = "Unable to resolve: {0} using {1} dns servers".format( + address, ",".join(dns_servers) + ) if error is not None: - msg += ' ({0})'.format(error) + msg += " ({0})".format(error) raise AssertionError(msg) def is_valid_ip(address): - """ validate ip address format """ + """validate ip address format""" try: to_ip_address(to_unicode(address)) return True @@ -400,7 +450,7 @@ def is_valid_ip(address): def is_valid_port(port): - """ validate port format """ + """validate port format""" if not port.isdigit(): return False @@ -409,8 +459,8 @@ def is_valid_port(port): def is_valid_port_range(port_range): - """ validate port range format """ - group = re.match(r'^(\d+)-(\d+)$', port_range) + """validate port range format""" + group = re.match(r"^(\d+)-(\d+)$", port_range) if not group: return False nport1 = int(group.group(1)) @@ -420,7 +470,7 @@ def is_valid_port_range(port_range): def is_valid_network(address): - """ validate network address format """ + """validate network address format""" try: to_ip_network(to_unicode(address)) return True @@ -430,7 +480,7 @@ def is_valid_network(address): def rule_product_dict(tab, rule, field, out_field=None): - """ Return cartesian product between rule[field] and tab as dicts """ + """Return cartesian product between rule[field] and tab as dicts""" if field not in rule: return tab if not out_field: @@ -446,7 +496,7 @@ def rule_product_dict(tab, rule, field, out_field=None): def rule_product_ports(rule, field, field_port): - """ Return cartesian product between rule[field] and field_port as string """ + """Return cartesian product between rule[field] and field_port as string""" if field_port not in rule: return rule[field] @@ -462,24 +512,29 @@ def rule_product_ports(rule, field, field_port): if not added: ret.append(alias) - return ' '.join(list(ret)) + return " ".join(list(ret)) def get_bool(values, field): - """ Return boolean field value from values """ + """Return boolean field value from values""" field_value = False if isinstance(values[field], bool): field_value = values[field] elif isinstance(values[field], str): - if values[field].lower() in ['yes', 'true']: + if values[field].lower() in ["yes", "true"]: field_value = True - elif values[field].lower() not in ['no', 'false']: - raise AnsibleError('{0} must be yes/no or true/false (got "{1}")'.format(field, values[field])) + elif values[field].lower() not in ["no", "false"]: + raise AnsibleError( + '{0} must be yes/no or true/false (got "{1}")'.format( + field, values[field] + ) + ) return field_value class PFSenseHostAlias(object): - """ Class holding structured pfsense host alias definition """ + """Class holding structured pfsense host alias definition""" + def __init__(self): self.name = None self.descr = None @@ -520,7 +575,9 @@ def copy(self): # networks for network in self.networks: new_network = ipaddress.IPv4Network.__new__(ipaddress.IPv4Network) - new_network.network_address = ipaddress.IPv4Address.__new__(ipaddress.IPv4Address) + new_network.network_address = ipaddress.IPv4Address.__new__( + ipaddress.IPv4Address + ) new_network.network_address._ip = network.network_address._ip new_network.netmask = network.netmask new_network._prefixlen = network._prefixlen @@ -539,10 +596,18 @@ def copy(self): def __str__(self): return "name={0}, descr={1}, definition={2}, ips={3}, networks={4}, local_interfaces={5}, routed_interfaces={6}, fake={7}".format( - self.name, self.descr, self.definition, self.ips, self.networks, self.local_interfaces, self.routed_interfaces, self.fake) + self.name, + self.descr, + self.definition, + self.ips, + self.networks, + self.local_interfaces, + self.routed_interfaces, + self.fake, + ) def compute_any(self, data): - """ Do all computations for object 'any' """ + """Do all computations for object 'any'""" # we add all interfaces of all pfsenses for pfsense in data.pfsenses_obj.values(): for interface in pfsense.interfaces.values(): @@ -550,16 +615,16 @@ def compute_any(self, data): self.routed_interfaces[pfsense.name].append(interface.name) def compute_all(self, data): - """ Do all computations """ + """Do all computations""" if not self._computed: self._computed = True - if self.name != 'any': + if self.name != "any": self.compute_addresses(data) self.compute_local_interfaces(data) self.compute_routed_interfaces(data) def compute_addresses(self, data): - """ Convert all aliases to structured ip addresses or networks """ + """Convert all aliases to structured ip addresses or networks""" todo = [] todo.extend(self.definition) @@ -568,7 +633,7 @@ def compute_addresses(self, data): address = todo.pop() # special case when (self) is used in nat rules - if address == '': + if address == "": continue # it's an ip address @@ -595,7 +660,9 @@ def compute_addresses(self, data): self.ips.append(host_ip) continue - raise AssertionError("Invalid address: " + address + " for " + self.name) + raise AssertionError( + "Invalid address: " + address + " for " + self.name + ) # it's another alias alias = data.hosts_aliases_obj.get(address) @@ -606,10 +673,10 @@ def compute_addresses(self, data): self.networks += alias.networks continue - todo.extend(data.all_aliases[address]['ip'].split()) + todo.extend(data.all_aliases[address]["ip"].split()) def _is_in_networks(self, interface, fcheckname): - """ check if an alias is in a network of an interface """ + """check if an alias is in a network of an interface""" fcheck = getattr(interface, fcheckname) for alias_ip in self.ips: if is_ip_broadcast(alias_ip): @@ -624,19 +691,19 @@ def _is_in_networks(self, interface, fcheckname): return True def is_in_local_network(self, interface): - """ check if an alias is in the local network of an interface """ - return self._is_in_networks(interface, 'local_network_contains') + """check if an alias is in the local network of an interface""" + return self._is_in_networks(interface, "local_network_contains") def is_in_remote_networks(self, interface): - """ check if an alias is in the remote networks of an interface """ - return self._is_in_networks(interface, 'remote_networks_contains') + """check if an alias is in the remote networks of an interface""" + return self._is_in_networks(interface, "remote_networks_contains") def is_in_adjacent_networks(self, interface): - """ check if an alias is in the adjacent networks of an interface """ - return self._is_in_networks(interface, 'adjacent_networks_contains') + """check if an alias is in the adjacent networks of an interface""" + return self._is_in_networks(interface, "adjacent_networks_contains") def compute_routed_interfaces(self, data): - """ Find all interfaces on all pfsense where the alias may be used as a routed source """ + """Find all interfaces on all pfsense where the alias may be used as a routed source""" for pfsense in data.pfsenses_obj.values(): # if the target alias is local, we does not consider the other interfaces self.routed_interfaces[pfsense.name] = set() @@ -653,7 +720,7 @@ def compute_routed_interfaces(self, data): self.routed_interfaces[pfsense.name].update(interfaces) def compute_local_interfaces(self, data): - """ Find all interfaces on all pfsense where the alias may be used as a local source """ + """Find all interfaces on all pfsense where the alias may be used as a local source""" for pfsense in data.pfsenses_obj.values(): self.local_interfaces[pfsense.name] = set() for alias_ip in self.ips: @@ -665,7 +732,7 @@ def compute_local_interfaces(self, data): self.local_interfaces[pfsense.name].update(interfaces) def is_whole_local(self, pfsense): - """ check if all ips/networks match a local network interface in pfense """ + """check if all ips/networks match a local network interface in pfense""" for alias_ip in self.ips: if is_ip_broadcast(alias_ip): continue @@ -680,19 +747,27 @@ def is_whole_local(self, pfsense): return True def routed_by_interfaces(self, pfsense, use_remote_networks=True): - """ return all interfaces for which all ips/networks match a adjacent/remote network in pfense """ + """return all interfaces for which all ips/networks match a adjacent/remote network in pfense""" all_interfaces = set() # we always search threw all interfaces to handle cases with internet for alias_ip in self.ips: - interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains(alias_ip, use_remote_networks=True) - interfaces = pfsense.hack_internet_routing(interfaces, alias_ip, use_remote_networks=True) + interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains( + alias_ip, use_remote_networks=True + ) + interfaces = pfsense.hack_internet_routing( + interfaces, alias_ip, use_remote_networks=True + ) all_interfaces.update(interfaces) for alias_net in self.networks: - interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains(alias_net, use_remote_networks=True) - interfaces = pfsense.hack_internet_routing(interfaces, alias_net, use_remote_networks=True) + interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains( + alias_net, use_remote_networks=True + ) + interfaces = pfsense.hack_internet_routing( + interfaces, alias_net, use_remote_networks=True + ) all_interfaces.update(interfaces) # if we didn't want the remote interfaces, we only keep the adjacents @@ -706,21 +781,21 @@ def routed_by_interfaces(self, pfsense, use_remote_networks=True): return all_interfaces def is_adjacent_or_remote(self, pfsense): - """ check if all ips/networks are in a adjacent/remote network in pfense """ + """check if all ips/networks are in a adjacent/remote network in pfense""" return len(self.routed_by_interfaces(pfsense, use_remote_networks=True)) > 0 def is_adjacent(self, pfsense): - """ check if all ips/networks are in a adjacent network in pfense """ + """check if all ips/networks are in a adjacent network in pfense""" return len(self.routed_by_interfaces(pfsense, use_remote_networks=False)) > 0 def is_ip_broadcast(self): - """ check if an alias is the ip_broadcast """ + """check if an alias is the ip_broadcast""" if len(self.ips) != 1 or self.networks: return False return is_ip_broadcast(self.ips[0]) def is_whole_in_pfsense(self, pfsense): - """ check if all ips/networks have as least one interface in pfense """ + """check if all ips/networks have as least one interface in pfense""" if self.name in pfsense.is_whole_in_pfsense_cache: return pfsense.is_whole_in_pfsense_cache[self.name] @@ -742,7 +817,7 @@ def is_whole_in_pfsense(self, pfsense): return True def is_whole_not_in_pfsense(self, pfsense): - """ check if all ips/networks have as least one interface in pfense """ + """check if all ips/networks have as least one interface in pfense""" if self.name in pfsense.is_whole_not_in_pfsense_cache: return pfsense.is_whole_not_in_pfsense_cache[self.name] @@ -763,7 +838,7 @@ def is_whole_not_in_pfsense(self, pfsense): return True def is_whole_in_same_routing_ifaces(self, pfsense): - """ check if all ips/networks have the same interfaces in pfense """ + """check if all ips/networks have the same interfaces in pfense""" if self.name in pfsense.is_whole_in_same_routing_ifaces_cache: return pfsense.is_whole_in_same_routing_ifaces_cache[self.name] @@ -774,7 +849,9 @@ def is_whole_in_same_routing_ifaces(self, pfsense): target_local_interfaces = None for alias_ip in self.ips: - interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains(alias_ip) + interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains( + alias_ip + ) interfaces = pfsense.hack_internet_routing(interfaces, alias_ip) if target_ar_interfaces is None: target_ar_interfaces = interfaces @@ -791,7 +868,9 @@ def is_whole_in_same_routing_ifaces(self, pfsense): return False for alias_net in self.networks: - interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains(alias_net) + interfaces = pfsense.interfaces_adjacent_or_remote_networks_contains( + alias_net + ) interfaces = pfsense.hack_internet_routing(interfaces, alias_net) if target_ar_interfaces is None: target_ar_interfaces = interfaces @@ -811,7 +890,7 @@ def is_whole_in_same_routing_ifaces(self, pfsense): return True def match_local_interface_ip(self, pfsense): - """ Return True if the alias ip match one interface on the pfsense """ + """Return True if the alias ip match one interface on the pfsense""" for alias_ip in self.ips: for iface in pfsense.interfaces.values(): for local_ip in iface.local_ips: @@ -821,7 +900,8 @@ def match_local_interface_ip(self, pfsense): class PFSenseRule(object): - """ Class holding structured pfsense rule declaration """ + """Class holding structured pfsense rule declaration""" + def __init__(self): self.name = None self.separator = None @@ -885,7 +965,7 @@ def copy(self): return copy_object def get_option(self, name): - """ return option value for name """ + """return option value for name""" if name in self.options: return self.options[name] separator = self.separator @@ -896,7 +976,7 @@ def get_option(self, name): return None def to_json(self): - """ return JSON String containing rule """ + """return JSON String containing rule""" srcs = [] for src in self.src: srcs.append(src.name) @@ -922,14 +1002,15 @@ def to_json(self): for field in OUTPUT_OPTION_FIELDS: value = self.get_option(field) if value is not None: - res += ', {0}: {1}'.format(field, value) + res += ", {0}: {1}".format(field, value) res += " }" return res class PFSenseRuleSeparator(object): - """ Class holding structured pfsense rule separator declaration """ + """Class holding structured pfsense rule separator declaration""" + def __init__(self): self.name = None self.interface = None @@ -940,15 +1021,20 @@ def __hash__(self): return hash(self.name + self.interface) def __eq__(self, other): - return self.__class__ == other.__class__ and self.name == other.name and self.interface == other.interface + return ( + self.__class__ == other.__class__ + and self.name == other.name + and self.interface == other.interface + ) class PFSenseInterface(object): - """ Class holding structured pfsense interface definition """ + """Class holding structured pfsense interface definition""" + def __init__(self): self.name = None - self.local_ip = None # first ip defined - self.local_network = None # first network defined + self.local_ip = None # first ip defined + self.local_network = None # first network defined self.local_ips = set() self.local_networks = set() self.remote_networks = set() @@ -960,27 +1046,39 @@ def __init__(self): @staticmethod def _networks_contains(address, networks): - """ return true if address is into networks """ + """return true if address is into networks""" if isinstance(address, ipaddress.IPv4Address): private_address = is_private_ip(address) for snet in networks: private_net = is_private_network(snet) - if private_address and private_net or not private_address and not private_net: + if ( + private_address + and private_net + or not private_address + and not private_net + ): if address in snet: return True elif isinstance(address, ipaddress.IPv4Network): private_address = is_private_network(address) for snet in networks: private_net = is_private_network(snet) - if private_address and private_net or not private_address and not private_net: + if ( + private_address + and private_net + or not private_address + and not private_net + ): if address.subnet_of(snet): return True else: - raise AssertionError('wrong type in remote_networks_contains:' + type(address)) + raise AssertionError( + "wrong type in remote_networks_contains:" + type(address) + ) return False def remote_networks_contains(self, address): - """ return true if address is defined in remote_networks of this interface """ + """return true if address is defined in remote_networks of this interface""" res = self._remote_networks_contains_cache.get(address) if res is None: res = self._networks_contains(address, self.remote_networks) @@ -988,7 +1086,7 @@ def remote_networks_contains(self, address): return res def adjacent_networks_contains(self, address): - """ return true if address is defined in adjacent_networks of this interface """ + """return true if address is defined in adjacent_networks of this interface""" res = self._adjacent_networks_contains_cache.get(address) if res is None: res = self._networks_contains(address, self.adjacent_networks) @@ -996,27 +1094,40 @@ def adjacent_networks_contains(self, address): return res def local_network_contains(self, address): - """ return true if address is in the local network of this interface """ + """return true if address is in the local network of this interface""" if self.local_networks: for local_network in self.local_networks: if isinstance(address, ipaddress.IPv4Address): private_address = is_private_ip(address) private_net = is_private_network(local_network) - if private_address and private_net or not private_address and not private_net: + if ( + private_address + and private_net + or not private_address + and not private_net + ): if address in local_network: return True elif isinstance(address, ipaddress.IPv4Network): private_address = is_private_network(address) private_net = is_private_network(local_network) - if private_address and private_net or not private_address and not private_net: + if ( + private_address + and private_net + or not private_address + and not private_net + ): if address.subnet_of(local_network): return True else: - raise AssertionError('wrong type in local_network_contains:' + type(address)) + raise AssertionError( + "wrong type in local_network_contains:" + type(address) + ) return False def are_in_same_network(self, src, dst): - """ return true if both the aliases are in the same network on the interface """ + """return true if both the aliases are in the same network on the interface""" + def _match(snet, alias): for ip in alias.ips: if ip not in snet: @@ -1033,7 +1144,8 @@ def _match(snet, alias): class PFSense(object): - """ Class holding structured pfsense definition """ + """Class holding structured pfsense definition""" + def __init__(self, name, interfaces): self.name = name self.interfaces = interfaces @@ -1046,23 +1158,27 @@ def __init__(self, name, interfaces): self._hack_internet_routing_cache = dict() def any_adjacent_networks_contains(self, address): - """ return true if address is defined in adjacent_networks of any interface """ + """return true if address is defined in adjacent_networks of any interface""" return len(self.interfaces_adjacent_networks_contains(address)) != 0 def any_remote_networks_contains(self, address): - """ return true if address is defined in remote_networks of any interface """ + """return true if address is defined in remote_networks of any interface""" return len(self.interfaces_remote_networks_contains(address)) != 0 def any_local_network_contains(self, address): - """ return true if address is defined in local network of any interface """ + """return true if address is defined in local network of any interface""" return len(self.interfaces_local_networks_contains(address)) != 0 def any_network_contains(self, address): - """ return true if address is defined in the local, remote or adjacent networks of any interface """ - return self.any_local_network_contains(address) or self.any_remote_networks_contains(address) or self.any_adjacent_networks_contains(address) + """return true if address is defined in the local, remote or adjacent networks of any interface""" + return ( + self.any_local_network_contains(address) + or self.any_remote_networks_contains(address) + or self.any_adjacent_networks_contains(address) + ) def _interfaces_network_contains(self, address, networks_name): - """ return interfaces names where address is in the interface network """ + """return interfaces names where address is in the interface network""" res = set() if isinstance(address, ipaddress.IPv4Address): private_address = is_private_ip(address) @@ -1073,7 +1189,12 @@ def _interfaces_network_contains(self, address, networks_name): networks = [networks] for snet in networks: private_net = is_private_network(snet) - if private_address and private_net or not private_address and not private_net: + if ( + private_address + and private_net + or not private_address + and not private_net + ): if address in snet: res.add(interface.name) elif isinstance(address, ipaddress.IPv4Network): @@ -1085,49 +1206,58 @@ def _interfaces_network_contains(self, address, networks_name): networks = [networks] for snet in networks: private_net = is_private_network(snet) - if private_address and private_net or not private_address and not private_net: + if ( + private_address + and private_net + or not private_address + and not private_net + ): if address.subnet_of(snet): res.add(interface.name) else: - raise AssertionError('wrong type in _interfaces_network_contains:' + type(address)) + raise AssertionError( + "wrong type in _interfaces_network_contains:" + type(address) + ) return res def interfaces_local_networks_contains(self, address): - """ return interfaces names where address is in the interface local network """ + """return interfaces names where address is in the interface local network""" res = self._interfaces_local_networks_contains_cache.get(address) if res is None: - res = self._interfaces_network_contains(address, 'local_networks') + res = self._interfaces_network_contains(address, "local_networks") self._interfaces_local_networks_contains_cache[address] = res return copy(res) def interfaces_remote_networks_contains(self, address): - """ return interfaces names where address is in the interface remote networks """ + """return interfaces names where address is in the interface remote networks""" res = self._interfaces_remote_networks_contains_cache.get(address) if res is None: - res = self._interfaces_network_contains(address, 'remote_networks') + res = self._interfaces_network_contains(address, "remote_networks") self._interfaces_remote_networks_contains_cache[address] = res return copy(res) def interfaces_adjacent_networks_contains(self, address): - """ return interfaces names where address is in the interface adjacent networks """ + """return interfaces names where address is in the interface adjacent networks""" res = self._interfaces_adjacent_networks_contains_cache.get(address) if res is None: - res = self._interfaces_network_contains(address, 'adjacent_networks') + res = self._interfaces_network_contains(address, "adjacent_networks") self._interfaces_adjacent_networks_contains_cache[address] = res return copy(res) - def interfaces_adjacent_or_remote_networks_contains(self, address, use_remote_networks=True): - """ return interfaces names where address are in the interface local or remote networks """ + def interfaces_adjacent_or_remote_networks_contains( + self, address, use_remote_networks=True + ): + """return interfaces names where address are in the interface local or remote networks""" res = self.interfaces_adjacent_networks_contains(address) if use_remote_networks: res.update(self.interfaces_remote_networks_contains(address)) return res - @static_vars(internet=ipaddress.IPv4Network((u"0.0.0.0", u"0.0.0.0"))) + @static_vars(internet=ipaddress.IPv4Network(("0.0.0.0", "0.0.0.0"))) def hack_internet_routing(self, interfaces, address, use_remote_networks=True): - """ internet (defined as route to 0.0.0.0/0) is an issue to automaticly detect interfaces on which routing is done because every host or network match. - if multiple interfaces can be used and there is at least one specific route which match the address, - we consider the internet ones as mistakes and remove them """ + """internet (defined as route to 0.0.0.0/0) is an issue to automaticly detect interfaces on which routing is done because every host or network match. + if multiple interfaces can be used and there is at least one specific route which match the address, + we consider the internet ones as mistakes and remove them""" key = str(interfaces) + str(address) + str(use_remote_networks) res = self._hack_internet_routing_cache.get(key) if res is None: @@ -1141,13 +1271,23 @@ def hack_internet_routing(self, interfaces, address, use_remote_networks=True): for network in self.interfaces[interface].remote_networks: if network == self.hack_internet_routing.internet: internet_found = True - elif is_net and address.overlaps(network) or not is_net and address in network: + elif ( + is_net + and address.overlaps(network) + or not is_net + and address in network + ): route_found = True for network in self.interfaces[interface].adjacent_networks: if network == self.hack_internet_routing.internet: internet_found = True - elif is_net and address.overlaps(network) or not is_net and address in network: + elif ( + is_net + and address.overlaps(network) + or not is_net + and address in network + ): route_found = True if route_found: @@ -1164,9 +1304,19 @@ def hack_internet_routing(self, interfaces, address, use_remote_networks=True): class PFSenseData(object): - """ Class holding all data """ - - def __init__(self, hosts_aliases, ports_aliases, pfsenses, rules, target_name, gendiff=False, debug=None, aggregate=True): + """Class holding all data""" + + def __init__( + self, + hosts_aliases, + ports_aliases, + pfsenses, + rules, + target_name, + gendiff=False, + debug=None, + aggregate=True, + ): self._hosts_aliases = hosts_aliases self._ports_aliases = ports_aliases self._pfsenses = pfsenses @@ -1190,94 +1340,94 @@ def __init__(self, hosts_aliases, ports_aliases, pfsenses, rules, target_name, g @property def all_aliases(self): - """ all_aliases getter """ + """all_aliases getter""" return self._all_aliases @property def hosts_aliases(self): - """ hosts_aliases getter """ + """hosts_aliases getter""" return self._hosts_aliases @property def ignored_aliases(self): - """ ignored_aliases getter """ + """ignored_aliases getter""" return self._ignored_aliases @property def hosts_aliases_obj(self): - """ hosts_aliases_obj getter """ + """hosts_aliases_obj getter""" return self._hosts_aliases_obj @property def ports_aliases(self): - """ ports_aliases getter """ + """ports_aliases getter""" return self._ports_aliases @property def pfsenses(self): - """ pfsenses getter """ + """pfsenses getter""" return self._pfsenses @property def pfsenses_obj(self): - """ pfsenses_obj getter """ + """pfsenses_obj getter""" return self._pfsenses_obj @property def rules_obj(self): - """ rules_obj getter """ + """rules_obj getter""" return self._rules_obj @property def rules(self): - """ rules getter """ + """rules getter""" return self._rules @property def ignored_rules(self): - """ ignored_rules getter """ + """ignored_rules getter""" return self._ignored_rules @property def rules_separators(self): - """ rules_separators getter """ + """rules_separators getter""" return self._rules_separators @property def target_name(self): - """ target_name getter """ + """target_name getter""" return self._target_name @property def target(self): - """ target getter """ + """target getter""" return self._target @target.setter def target(self, target): - """ target setter """ + """target setter""" self._target = target @property def errors(self): - """ errors getter """ + """errors getter""" return self._errors def set_error(self, error): - """ add an error """ + """add an error""" display.error(error) self._errors.append(error) @staticmethod def is_child_def(values): - """ check if values contains more definitions """ + """check if values contains more definitions""" for value in values.values(): if isinstance(value, (OrderedDict, dict, list)): return False return True def unalias_ip(self, alias): - """ expand alias to it's ip definition """ + """expand alias to it's ip definition""" ret = [] todo = [] todo.extend(alias.split()) @@ -1285,15 +1435,15 @@ def unalias_ip(self, alias): while todo: elts = todo.pop() if elts in self._all_aliases: - todo.extend(self._all_aliases[elts]['ip'].split()) + todo.extend(self._all_aliases[elts]["ip"].split()) else: ret.append(elts) return ret def get_hosts_alias(self, hosts, ips, networks, _basename): - """ return an alias with all the hosts - create it if required """ + """return an alias with all the hosts + create it if required""" searched = ips.union(networks) for alias in self._hosts_aliases_obj.values(): @@ -1312,7 +1462,7 @@ def get_hosts_alias(self, hosts, ips, networks, _basename): basename = _basename[0:26] idx = 1 while True: - obj.name = 'h_{0}_{1}'.format(basename, idx) + obj.name = "h_{0}_{1}".format(basename, idx) if obj.name not in self._all_aliases: break idx = idx + 1 @@ -1320,18 +1470,18 @@ def get_hosts_alias(self, hosts, ips, networks, _basename): self._hosts_aliases_obj[obj.name] = obj alias = dict() - alias['ip'] = ' '.join(obj.definition) - alias['type'] = 'network' + alias["ip"] = " ".join(obj.definition) + alias["type"] = "network" self._all_aliases[obj.name] = alias return obj def get_ports_alias(self, ports, _basename): - """ return an alias with all the ports - create it if required """ + """return an alias with all the ports + create it if required""" for name, alias in self._ports_aliases.items(): - alias_ports = set(alias['port'].split()) + alias_ports = set(alias["port"].split()) if not alias_ports ^ ports: return name @@ -1339,16 +1489,16 @@ def get_ports_alias(self, ports, _basename): basename = _basename[0:26] idx = 1 while True: - name = 'p_{0}_{1}'.format(basename, idx) + name = "p_{0}_{1}".format(basename, idx) if name not in self._all_aliases: break idx = idx + 1 alias = dict() - alias['descr'] = name + alias["descr"] = name sorted_ports = list(ports) sorted_ports.sort() - alias['port'] = ' '.join(sorted_ports) + alias["port"] = " ".join(sorted_ports) self._all_aliases[name] = alias self._ports_aliases[name] = alias @@ -1356,25 +1506,32 @@ def get_ports_alias(self, ports, _basename): class PFSenseDataParser(object): - """ Class doing all data checks and pfsense objects generation """ + """Class doing all data checks and pfsense objects generation""" def __init__(self, data): self._data = data @staticmethod def check_alias_name(name): - """ check an alias name """ + """check an alias name""" # todo: check reserved keywords (any, self, ...) - if re.match('^[a-zA-Z0-9_]+$', name) is None: - raise AnsibleError(name + ': the name of the alias may only consist of the characters "a-z, A-Z, 0-9 and _"') - - def parse_host_alias(self, obj, src_name, type_name, name, allow_any, dns_servers=None): - """ Parse an host alias definition """ + if re.match("^[a-zA-Z0-9_]+$", name) is None: + raise AnsibleError( + name + + ': the name of the alias may only consist of the characters "a-z, A-Z, 0-9 and _"' + ) + + def parse_host_alias( + self, obj, src_name, type_name, name, allow_any, dns_servers=None + ): + """Parse an host alias definition""" ret = True value = obj[src_name] values = str(value).split() if not values: - self._data.set_error("Empty " + src_name + " field for " + type_name + " " + name) + self._data.set_error( + "Empty " + src_name + " field for " + type_name + " " + name + ) return False # we check that all exists @@ -1385,294 +1542,366 @@ def parse_host_alias(self, obj, src_name, type_name, name, allow_any, dns_server for value in values: if is_valid_ip(value): if value not in self._data.hosts_aliases_obj: - self._data.hosts_aliases_obj[value] = self.create_obj_host_alias(value) + self._data.hosts_aliases_obj[value] = self.create_obj_host_alias( + value + ) ip_defs = ip_defs + 1 continue if is_valid_network(value): if value not in self._data.hosts_aliases_obj: - self._data.hosts_aliases_obj[value] = self.create_obj_host_alias(value) + self._data.hosts_aliases_obj[value] = self.create_obj_host_alias( + value + ) net_defs = net_defs + 1 continue - if value not in self._data.hosts_aliases and (value != 'any' or not allow_any): + if value not in self._data.hosts_aliases and ( + value != "any" or not allow_any + ): if is_fqdn(value): if value not in self._data.hosts_aliases_obj: - self._data.hosts_aliases_obj[value] = self.create_obj_host_alias(value, dns_servers) + self._data.hosts_aliases_obj[value] = ( + self.create_obj_host_alias(value, dns_servers) + ) fqdn_defs = fqdn_defs + 1 continue - self._data.set_error(value + " is not a valid alias, ip address or network in " + type_name + " " + name) + self._data.set_error( + value + + " is not a valid alias, ip address or network in " + + type_name + + " " + + name + ) ret = False other_defs = other_defs + 1 if fqdn_defs and (ip_defs + net_defs + other_defs) > 0: - self._data.set_error("fqdn definitions can't be mixed with aliases, IP or networks addresses (in " + type_name + " " + name + ")") + self._data.set_error( + "fqdn definitions can't be mixed with aliases, IP or networks addresses (in " + + type_name + + " " + + name + + ")" + ) ret = False # if it's a real alias, we must check for mixed network definitions if not allow_any: if net_defs > 0: if net_defs != len(values): - self._data.set_error("mixed network definitions and aliases or IP addresses in " + type_name + " " + name) + self._data.set_error( + "mixed network definitions and aliases or IP addresses in " + + type_name + + " " + + name + ) ret = False else: - obj['type'] = 'network' + obj["type"] = "network" else: - obj['type'] = 'host' + obj["type"] = "host" return ret def parse_hosts_aliases(self): - """ Parse all hosts aliases definitions """ + """Parse all hosts aliases definitions""" dups = {} ret = True for name, alias in self._data.hosts_aliases.items(): self.check_alias_name(name) - if 'ignored' in alias and get_bool(alias, 'ignored'): + if "ignored" in alias and get_bool(alias, "ignored"): self._data.ignored_aliases.add(name) continue # ip field is mandatory - if 'ip' not in alias and 'host' not in alias: + if "ip" not in alias and "host" not in alias: self._data.set_error("No ip or host field for alias " + name) ret = False continue # we check that all fields are valid for field in alias: - if field not in ['ip', 'host', 'descr', 'dns', 'ignore_dup', 'ignored']: - self._data.set_error(field + " is not a valid field name in alias " + name) + if field not in ["ip", "host", "descr", "dns", "ignore_dup", "ignored"]: + self._data.set_error( + field + " is not a valid field name in alias " + name + ) ret = False dns_servers = None - if 'dns' in alias: - dns_servers = alias['dns'].split() + if "dns" in alias: + dns_servers = alias["dns"].split() # we check that all ip exist and are not empty - if not self.parse_host_alias(alias, 'ip', 'alias', name, False, dns_servers): + if not self.parse_host_alias( + alias, "ip", "alias", name, False, dns_servers + ): ret = False continue # we check for duplicates _alias = deepcopy(alias) - if 'descr' in _alias: - del _alias['descr'] + if "descr" in _alias: + del _alias["descr"] dup = json.dumps(_alias) if dup in dups: - display.warning("duplicate alias definition for ip " + alias['ip'] + " (" + dups[dup] + ", " + name + ")") - elif 'ignore_dup' not in alias: + display.warning( + "duplicate alias definition for ip " + + alias["ip"] + + " (" + + dups[dup] + + ", " + + name + + ")" + ) + elif "ignore_dup" not in alias: dups[dup] = name obj = PFSenseHostAlias() obj.name = name - obj.definition = alias['ip'].split() - if 'descr' in alias: - obj.descr = alias['descr'] + obj.definition = alias["ip"].split() + if "descr" in alias: + obj.descr = alias["descr"] obj.dns = dns_servers self._data.hosts_aliases_obj[obj.name] = obj return ret def check_port_alias(self, ports, src_name, type_name, name): - """ Checking a port alias definition """ + """Checking a port alias definition""" ret = True values = str(ports).split() - if src_name == 'dst_nat_port': - if '-' in ports: - self._data.set_error("There must be only one port in dst_nat_port of " + name) + if src_name == "dst_nat_port": + if "-" in ports: + self._data.set_error( + "There must be only one port in dst_nat_port of " + name + ) return False if len(values) > 1: - self._data.set_error("There must be only one port in {0} of {1}".format(src_name, name)) + self._data.set_error( + "There must be only one port in {0} of {1}".format(src_name, name) + ) return False if not values: - self._data.set_error("Empty " + src_name + " field for " + type_name + " " + name) + self._data.set_error( + "Empty " + src_name + " field for " + type_name + " " + name + ) return False # we check that all exists for value in values: - if not is_valid_port(value) and not is_valid_port_range(value) and value not in self._data.ports_aliases: - self._data.set_error(value + " is not a valid alias, port or port range in " + type_name + " " + name) + if ( + not is_valid_port(value) + and not is_valid_port_range(value) + and value not in self._data.ports_aliases + ): + self._data.set_error( + value + + " is not a valid alias, port or port range in " + + type_name + + " " + + name + ) ret = False return ret def parse_ports_aliases(self): - """ Checking all ports alias definitions """ + """Checking all ports alias definitions""" dups = {} ret = True for name, alias in self._data.ports_aliases.items(): self.check_alias_name(name) # port field is mandatory - if 'port' not in alias: + if "port" not in alias: self._data.set_error("No port field for alias " + name) ret = False continue - if not isinstance(alias['port'], str): - alias['port'] = str(alias['port']) + if not isinstance(alias["port"], str): + alias["port"] = str(alias["port"]) # we check that all ip exist and are not empty - if not self.check_port_alias(alias['port'], 'port', 'alias', name): + if not self.check_port_alias(alias["port"], "port", "alias", name): ret = False continue # we check that all fields are valid for field in alias: - if field != 'port' and field != 'descr': - self._data.set_error(field + " is not a valid field name in alias " + name) + if field != "port" and field != "descr": + self._data.set_error( + field + " is not a valid field name in alias " + name + ) ret = False # we check for duplicates _alias = deepcopy(alias) - if 'descr' in _alias: - del _alias['descr'] + if "descr" in _alias: + del _alias["descr"] dup = json.dumps(_alias) if dup in dups: - display.warning("duplicate alias definition for port " + alias['port'] + " (" + dups[dup] + ", " + name + ")") + display.warning( + "duplicate alias definition for port " + + alias["port"] + + " (" + + dups[dup] + + ", " + + name + + ")" + ) else: dups[dup] = name return ret def create_obj_any_alias(self): - """ Create a PFSenseHostAlias object for address any (for easier processing later) """ + """Create a PFSenseHostAlias object for address any (for easier processing later)""" obj = PFSenseHostAlias() - obj.name = 'any' - obj.definition = ['any'] + obj.name = "any" + obj.definition = ["any"] obj.fake = True obj.compute_any(self._data) - self._data.all_aliases['any'] = {} - self._data.all_aliases['any']['ip'] = '0.0.0.0/0' - self._data.all_aliases['any']['type'] = 'network' + self._data.all_aliases["any"] = {} + self._data.all_aliases["any"]["ip"] = "0.0.0.0/0" + self._data.all_aliases["any"]["type"] = "network" return obj def create_obj_host_alias(self, src, dns_servers=None): - """ Create a PFSenseHostAlias object from address (for easier processing later) """ + """Create a PFSenseHostAlias object from address (for easier processing later)""" obj = PFSenseHostAlias() obj.name = src obj.definition = [src] obj.fake = True obj.dns = dns_servers - if src == 'any': + if src == "any": return self.create_obj_any_alias() return obj def create_obj_rule_from_def(self, name, rule, separator): - """ Create a PFSenseRule object from yaml definition """ + """Create a PFSenseRule object from yaml definition""" + def _get_bool(field): field_value = False if isinstance(rule[field], bool): field_value = rule[field] elif isinstance(rule[field], str): - if rule[field].lower() in ['yes', 'true']: + if rule[field].lower() in ["yes", "true"]: field_value = True - elif rule[field].lower() not in ['no', 'false']: - self._data.set_error('{0} must be yes/no or true/false (got "{1}")'.format(field, rule[field])) + elif rule[field].lower() not in ["no", "false"]: + self._data.set_error( + '{0} must be yes/no or true/false (got "{1}")'.format( + field, rule[field] + ) + ) return field_value obj = PFSenseRule() obj.name = name obj.separator = separator - if 'src_port' in rule: - if not isinstance(rule['src_port'], str): - obj.src_port = str(rule['src_port']) + if "src_port" in rule: + if not isinstance(rule["src_port"], str): + obj.src_port = str(rule["src_port"]) else: - obj.src_port = rule['src_port'].split() + obj.src_port = rule["src_port"].split() - if 'dst_port' in rule: - if not isinstance(rule['dst_port'], str): - obj.dst_port = str(rule['dst_port']) + if "dst_port" in rule: + if not isinstance(rule["dst_port"], str): + obj.dst_port = str(rule["dst_port"]) else: - obj.dst_port = rule['dst_port'].split() + obj.dst_port = rule["dst_port"].split() - if 'dst_nat_port' in rule: - if not isinstance(rule['dst_nat_port'], str): - obj.dst_nat_port = str(rule['dst_nat_port']) + if "dst_nat_port" in rule: + if not isinstance(rule["dst_nat_port"], str): + obj.dst_nat_port = str(rule["dst_nat_port"]) else: - obj.dst_nat_port = rule['dst_nat_port'].split() + obj.dst_nat_port = rule["dst_nat_port"].split() - if 'protocol' in rule: - obj.protocol = rule['protocol'].split() + if "protocol" in rule: + obj.protocol = rule["protocol"].split() - if 'action' in rule: - obj.action = rule['action'] + if "action" in rule: + obj.action = rule["action"] for field in OPTION_FIELDS: if field in rule: obj.options[field] = rule[field] - if 'force' in rule: - obj.force = _get_bool('force') - if obj.force and not (obj.get_option('filter') and obj.get_option('ifilter')): - self._data.set_error('force must not be used without filter and ifilter') + if "force" in rule: + obj.force = _get_bool("force") + if obj.force and not ( + obj.get_option("filter") and obj.get_option("ifilter") + ): + self._data.set_error( + "force must not be used without filter and ifilter" + ) - if 'floating' in rule: - obj.floating = _get_bool('floating') + if "floating" in rule: + obj.floating = _get_bool("floating") - if 'quick' in rule and not obj.floating: - self._data.set_error('Quick must only be used with floating rules') + if "quick" in rule and not obj.floating: + self._data.set_error("Quick must only be used with floating rules") - for src in rule['src'].split(): + for src in rule["src"].split(): if src not in self._data.hosts_aliases_obj: self._data.hosts_aliases_obj[src] = self.create_obj_host_alias(src) target = self._data.hosts_aliases_obj[src] obj.src.append(target) - for dst in rule['dst'].split(): + for dst in rule["dst"].split(): if dst not in self._data.hosts_aliases_obj: self._data.hosts_aliases_obj[dst] = self.create_obj_host_alias(dst) target = self._data.hosts_aliases_obj[dst] obj.dst.append(target) - if 'src_nat' in rule: - src = rule['src_nat'] + if "src_nat" in rule: + src = rule["src_nat"] if src not in self._data.hosts_aliases_obj: self._data.hosts_aliases_obj[src] = self.create_obj_host_alias(src) target = self._data.hosts_aliases_obj[src] obj.src_nat.append(target) - if 'dst_nat' in rule: - dst = rule['dst_nat'] + if "dst_nat" in rule: + dst = rule["dst_nat"] if dst not in self._data.hosts_aliases_obj: self._data.hosts_aliases_obj[dst] = self.create_obj_host_alias(dst) target = self._data.hosts_aliases_obj[dst] obj.dst_nat.append(target) - if 'asymmetric' in rule: - obj.asymmetric = _get_bool('asymmetric') + if "asymmetric" in rule: + obj.asymmetric = _get_bool("asymmetric") - if 'invert_src' in rule: - obj.invert_src = _get_bool('invert_src') + if "invert_src" in rule: + obj.invert_src = _get_bool("invert_src") if not obj.force: - self._data.set_error('invert_src must be used with force (for now)') + self._data.set_error("invert_src must be used with force (for now)") - if 'invert_dst' in rule: - obj.invert_dst = _get_bool('invert_dst') + if "invert_dst" in rule: + obj.invert_dst = _get_bool("invert_dst") if not obj.force: - self._data.set_error('invert_dst must be used with force (for now)') + self._data.set_error("invert_dst must be used with force (for now)") - if 'invert_src_nat' in rule: - obj.invert_src_nat = _get_bool('invert_src_nat') + if "invert_src_nat" in rule: + obj.invert_src_nat = _get_bool("invert_src_nat") if not obj.force: - self._data.set_error('invert_src_nat must be used with force (for now)') + self._data.set_error("invert_src_nat must be used with force (for now)") - if 'invert_dst_nat' in rule: - obj.invert_dst_nat = _get_bool('invert_dst_nat') + if "invert_dst_nat" in rule: + obj.invert_dst_nat = _get_bool("invert_dst_nat") if not obj.force: - self._data.set_error('invert_dst_nat must be used with force (for now)') + self._data.set_error("invert_dst_nat must be used with force (for now)") return obj def parse_rules(self, parent=None, parent_separator=None): - """ Parse all rules definitions """ + """Parse all rules definitions""" ret = True if parent is None: parent = self._data.rules @@ -1689,83 +1918,126 @@ def parse_rules(self, parent=None, parent_separator=None): if parent_separator.name is None or not parent_separator.name: separator.name = name else: - separator.name = parent_separator.name + ' - ' + name + separator.name = parent_separator.name + " - " + name self._data.rules_separators.append(separator) if not self.parse_rules(rule, separator): ret = False continue - if name == 'options': + if name == "options": parent_separator.options = rule - if parent_separator.options and parent_separator.options.get('invisible'): + if parent_separator.options and parent_separator.options.get( + "invisible" + ): if parent_separator.name is None or not parent_separator.name: - parent_separator.name = '' + parent_separator.name = "" else: parent_separator.name = parent_separator.parent.name continue - if 'ignored' in rule and get_bool(rule, 'ignored'): + if "ignored" in rule and get_bool(rule, "ignored"): self._data.ignored_rules.add(name) continue - for field in ['src', 'dst']: + for field in ["src", "dst"]: # src and dst field are mandatory if field not in rule: - self._data.set_error("No {0} field for rule {1}".format(field, name)) + self._data.set_error( + "No {0} field for rule {1}".format(field, name) + ) ret = False continue # we check that all exist and are not empty - if not self.parse_host_alias(rule, field, 'rule', name, True): + if not self.parse_host_alias(rule, field, "rule", name, True): ret = False - for field in ['src_nat', 'dst_nat']: + for field in ["src_nat", "dst_nat"]: if field not in rule: continue if len(rule[field].split()) > 1: - self._data.set_error('There must be only one address in {0} field of {1}'.format(field, name)) + self._data.set_error( + "There must be only one address in {0} field of {1}".format( + field, name + ) + ) ret = False continue - if field == 'src_nat' and rule['src_nat'] == '(self)': - rule['src_nat'] = '' + if field == "src_nat" and rule["src_nat"] == "(self)": + rule["src_nat"] = "" continue # we check that all exist and are not empty - if not self.parse_host_alias(rule, field, 'rule', name, True): + if not self.parse_host_alias(rule, field, "rule", name, True): ret = False # checking ports - for field in ['src_port', 'dst_port', 'dst_nat_port']: + for field in ["src_port", "dst_port", "dst_nat_port"]: if field in rule: if not isinstance(rule[field], str): rule[field] = str(rule[field]) - if not self.check_port_alias(rule[field], field, 'rule', name) or not self.check_tcp_udp(rule, name): + if not self.check_port_alias( + rule[field], field, "rule", name + ) or not self.check_tcp_udp(rule, name): ret = False - if 'dst_nat_port' in rule and 'dst_nat' not in rule: - self._data.set_error('dst_nat_port field is set on {0} without any dst_nat target'.format(name)) + if "dst_nat_port" in rule and "dst_nat" not in rule: + self._data.set_error( + "dst_nat_port field is set on {0} without any dst_nat target".format( + name + ) + ) ret = False continue - if 'dst_nat' in rule and 'dst_nat_port' not in rule: - self._data.set_error('dst_nat field is set on {0} without any dst_nat_port target'.format(name)) + if "dst_nat" in rule and "dst_nat_port" not in rule: + self._data.set_error( + "dst_nat field is set on {0} without any dst_nat_port target".format( + name + ) + ) ret = False continue # we check that all fields are valid - valid_fields = ['src', 'dst', 'src_port', 'dst_port', 'protocol', 'action', 'floating', 'force'] - valid_fields.extend(['src_nat', 'dst_nat', 'dst_nat_port', 'asymmetric', 'invert_dst', 'invert_src', 'invert_dst_nat', 'invert_src_nat', 'ignored']) + valid_fields = [ + "src", + "dst", + "src_port", + "dst_port", + "protocol", + "action", + "floating", + "force", + ] + valid_fields.extend( + [ + "src_nat", + "dst_nat", + "dst_nat_port", + "asymmetric", + "invert_dst", + "invert_src", + "invert_dst_nat", + "invert_src_nat", + "ignored", + ] + ) valid_fields.extend(OPTION_FIELDS) for field in rule: if field not in valid_fields: - self._data.set_error(field + " is not a valid field name in rule " + name) + self._data.set_error( + field + " is not a valid field name in rule " + name + ) ret = False if name in self._data.rules_obj: display.warning("Rule already defined: {0}".format(name)) - self._data.rules_obj[name] = self.create_obj_rule_from_def(name, rule, parent_separator) + self._data.rules_obj[name] = self.create_obj_rule_from_def( + name, rule, parent_separator + ) if self._data.errors: return False @@ -1773,42 +2045,60 @@ def parse_rules(self, parent=None, parent_separator=None): return ret def parse_target_name(self): - """ Parse target's name definition """ + """Parse target's name definition""" if self._data.target_name not in self._data.pfsenses: - self._data.set_error(self._data.target_name + " does not exist in pfsenses section") + self._data.set_error( + self._data.target_name + " does not exist in pfsenses section" + ) return False self._data.target = self._data.pfsenses_obj[self._data.target_name] return True def check_tcp_udp(self, rule, name): - """ check if protocol is valid when ports are sets """ - if 'protocol' not in rule: + """check if protocol is valid when ports are sets""" + if "protocol" not in rule: return True - protocols = str(rule['protocol']).split() + protocols = str(rule["protocol"]).split() for protocol in protocols: - if protocol != 'udp' and protocol != 'tcp' and protocol != 'tcp/udp': - self._data.set_error(protocol + " protocol used with src_port or dst_port in rule " + name) + if protocol != "udp" and protocol != "tcp" and protocol != "tcp/udp": + self._data.set_error( + protocol + + " protocol used with src_port or dst_port in rule " + + name + ) return False return True def check_pfsense_interfaces_objs(self, interfaces, name): - """ Checking all interfaces networks between them """ + """Checking all interfaces networks between them""" for src_name, src in interfaces.items(): for dst_name, dst in interfaces.items(): if src_name != dst_name and src.local_networks and dst.local_networks: for src_network in src.local_networks: for dst_network in dst.local_networks: if src_network.overlaps(dst_network): - self._data.set_error("Local networks of " + src_name + " and " + dst_name + " overlap in " + name) + self._data.set_error( + "Local networks of " + + src_name + + " and " + + dst_name + + " overlap in " + + name + ) return False return True def create_pfsenses_aliases(self): - """ Generate usefull aliases for pfsenses """ + """Generate usefull aliases for pfsenses""" + def _warn_alias(alias): if len(alias) >= 32: - display.warning("Autogenerated alias {0} is too long and will trigger an error if used".format(alias)) + display.warning( + "Autogenerated alias {0} is too long and will trigger an error if used".format( + alias + ) + ) # if pf_paris has lan and wan interfaces declared, we will generate # - pf_paris_ips with all its ips @@ -1842,13 +2132,12 @@ def _warn_alias(alias): net_all_pfsenses_interfaces_definition = dict() for name, pfsense in self._data.pfsenses.items(): - # maybe do some options for this - ipext = '_ips' - netext = '_nets' + ipext = "_ips" + netext = "_nets" # interfaces field is mandatory - if 'interfaces' not in pfsense: + if "interfaces" not in pfsense: continue # pf_paris_ips @@ -1863,20 +2152,20 @@ def _warn_alias(alias): # pf_paris_lan_nets, pf_paris_tag_nets net_interfaces_definition = dict() - for iname, interface in pfsense['interfaces'].items(): - if 'ip' not in interface: + for iname, interface in pfsense["interfaces"].items(): + if "ip" not in interface: continue - if 'id' in interface and interface.get('id'): - interface_id = str(interface['id']) + if "id" in interface and interface.get("id"): + interface_id = str(interface["id"]) else: interface_id = iname # get the tags tags = list() tags.append(interface_id) - if 'tags' in interface: - for tag in sorted(interface['tags'].split()): + if "tags" in interface: + for tag in sorted(interface["tags"].split()): if tag not in tags: tags.append(tag) @@ -1887,7 +2176,7 @@ def _warn_alias(alias): net_interface_definition = list() # get the ips and networks - for ip in interface['ip'].split(): + for ip in interface["ip"].split(): try: local_network = to_ip_network(to_unicode(ip), False) str_net = str(local_network) @@ -1896,7 +2185,7 @@ def _warn_alias(alias): # we will fail later pass - group = re.match(r'([^\/]*)\/(\d+)', ip) + group = re.match(r"([^\/]*)\/(\d+)", ip) try: if group: ip = to_ip_address(to_unicode(group.group(1))) @@ -1909,13 +2198,18 @@ def _warn_alias(alias): # interface ip alias if interface_definition: for idx, tag in enumerate(tags): - interface_definition_name = name + '_' + tag + ipext + interface_definition_name = name + "_" + tag + ipext - ikey = 'all_' + tag + ipext + ikey = "all_" + tag + ipext if ikey not in all_pfsenses_interfaces_definition: all_pfsenses_interfaces_definition[ikey] = list() - if interface_definition_name not in all_pfsenses_interfaces_definition[ikey]: - all_pfsenses_interfaces_definition[ikey].append(interface_definition_name) + if ( + interface_definition_name + not in all_pfsenses_interfaces_definition[ikey] + ): + all_pfsenses_interfaces_definition[ikey].append( + interface_definition_name + ) if idx == 0: # we only add the interface and not all the tags in the pfsense_definition @@ -1924,18 +2218,25 @@ def _warn_alias(alias): if interface_definition_name not in interfaces_definition: interfaces_definition[interface_definition_name] = list() - interfaces_definition[interface_definition_name].extend(interface_definition) + interfaces_definition[interface_definition_name].extend( + interface_definition + ) # interface network alias if net_interface_definition: for idx, tag in enumerate(tags): - interface_definition_name = name + '_' + tag + netext + interface_definition_name = name + "_" + tag + netext - ikey_net = 'all_' + tag + netext + ikey_net = "all_" + tag + netext if ikey_net not in net_all_pfsenses_interfaces_definition: net_all_pfsenses_interfaces_definition[ikey_net] = list() - if interface_definition_name not in net_all_pfsenses_interfaces_definition[ikey_net]: - net_all_pfsenses_interfaces_definition[ikey_net].append(interface_definition_name) + if ( + interface_definition_name + not in net_all_pfsenses_interfaces_definition[ikey_net] + ): + net_all_pfsenses_interfaces_definition[ikey_net].append( + interface_definition_name + ) if idx == 0: # we only add the interface and not all the tags in the pfsense_definition @@ -1943,8 +2244,12 @@ def _warn_alias(alias): net_interface_definition.sort() if interface_definition_name not in net_interfaces_definition: - net_interfaces_definition[interface_definition_name] = list() - net_interfaces_definition[interface_definition_name].extend(net_interface_definition) + net_interfaces_definition[interface_definition_name] = ( + list() + ) + net_interfaces_definition[interface_definition_name].extend( + net_interface_definition + ) # pf_paris_lan_ips, pf_paris_tag_ips for interface_definition_name in sorted(interfaces_definition.keys()): @@ -1953,21 +2258,23 @@ def _warn_alias(alias): interface_definition.sort() alias = dict() - alias['ip'] = ' '.join(interface_definition) - alias['ignore_dup'] = True + alias["ip"] = " ".join(interface_definition) + alias["ignore_dup"] = True self._data.all_aliases[interface_definition_name] = alias self._data.hosts_aliases[interface_definition_name] = alias _warn_alias(interface_definition_name) # pf_paris_lan_nets, pf_paris_tag_nets for interface_definition_name in sorted(net_interfaces_definition.keys()): - net_interface_definition = net_interfaces_definition[interface_definition_name] + net_interface_definition = net_interfaces_definition[ + interface_definition_name + ] net_interface_definition = list(dict.fromkeys(net_interface_definition)) net_interface_definition.sort() alias = dict() - alias['ip'] = ' '.join(net_interface_definition) - alias['ignore_dup'] = True + alias["ip"] = " ".join(net_interface_definition) + alias["ignore_dup"] = True self._data.all_aliases[interface_definition_name] = alias self._data.hosts_aliases[interface_definition_name] = alias _warn_alias(interface_definition_name) @@ -1978,8 +2285,8 @@ def _warn_alias(alias): all_pfsenses_definition.append(name + ipext) alias = dict() - alias['ip'] = ' '.join(pfsense_definition) - alias['ignore_dup'] = True + alias["ip"] = " ".join(pfsense_definition) + alias["ignore_dup"] = True self._data.all_aliases[name + ipext] = alias self._data.hosts_aliases[name + ipext] = alias _warn_alias(name + ipext) @@ -1990,8 +2297,8 @@ def _warn_alias(alias): net_all_pfsenses_definition.append(name + netext) alias = dict() - alias['ip'] = ' '.join(net_pfsense_definition) - alias['ignore_dup'] = True + alias["ip"] = " ".join(net_pfsense_definition) + alias["ignore_dup"] = True self._data.all_aliases[name + netext] = alias self._data.hosts_aliases[name + netext] = alias _warn_alias(name + netext) @@ -2002,8 +2309,8 @@ def _warn_alias(alias): definition.sort() alias = dict() - alias['ip'] = ' '.join(definition) - alias['ignore_dup'] = True + alias["ip"] = " ".join(definition) + alias["ignore_dup"] = True self._data.all_aliases[name] = alias self._data.hosts_aliases[name] = alias _warn_alias(name) @@ -2014,8 +2321,8 @@ def _warn_alias(alias): definition.sort() alias = dict() - alias['ip'] = ' '.join(definition) - alias['ignore_dup'] = True + alias["ip"] = " ".join(definition) + alias["ignore_dup"] = True self._data.all_aliases[name] = alias self._data.hosts_aliases[name] = alias _warn_alias(name) @@ -2024,25 +2331,25 @@ def _warn_alias(alias): all_pfsenses_definition.sort() alias = dict() - alias['ip'] = ' '.join(all_pfsenses_definition) - alias['ignore_dup'] = True - self._data.all_aliases['all_pfsenses' + ipext] = alias - self._data.hosts_aliases['all_pfsenses' + ipext] = alias + alias["ip"] = " ".join(all_pfsenses_definition) + alias["ignore_dup"] = True + self._data.all_aliases["all_pfsenses" + ipext] = alias + self._data.hosts_aliases["all_pfsenses" + ipext] = alias # generate all network aliases net_all_pfsenses_definition.sort() alias = dict() - alias['ip'] = ' '.join(net_all_pfsenses_definition) - alias['ignore_dup'] = True - self._data.all_aliases['all_pfsenses' + netext] = alias - self._data.hosts_aliases['all_pfsenses' + netext] = alias + alias["ip"] = " ".join(net_all_pfsenses_definition) + alias["ignore_dup"] = True + self._data.all_aliases["all_pfsenses" + netext] = alias + self._data.hosts_aliases["all_pfsenses" + netext] = alias def parse_pfsense_interfaces(self, pfsense, name): - """ Parse all pfsense interfaces definitions """ + """Parse all pfsense interfaces definitions""" ret = {} ids = set() - for iname, interface in pfsense['interfaces'].items(): + for iname, interface in pfsense["interfaces"].items(): # extracting & checking local network local_ips = set() local_networks = set() @@ -2051,27 +2358,42 @@ def parse_pfsense_interfaces(self, pfsense, name): tags = set() tags.add(iname) - if 'tags' in interface: - for tag in interface['tags'].split(): + if "tags" in interface: + for tag in interface["tags"].split(): tags.add(tag) for key in interface: - if key not in ['adjacent_networks', 'remote_networks', 'ip', 'tags', 'id']: - self._data.set_error("Invalid field " + key + " in " + iname + " of " + name) + if key not in [ + "adjacent_networks", + "remote_networks", + "ip", + "tags", + "id", + ]: + self._data.set_error( + "Invalid field " + key + " in " + iname + " of " + name + ) return {} - if 'id' in interface and interface.get('id'): - interface_id = str(interface['id']) + if "id" in interface and interface.get("id"): + interface_id = str(interface["id"]) else: interface_id = iname if interface_id in ids: - self._data.set_error("Duplicate interface id " + interface_id + " in " + iname + " of " + name) + self._data.set_error( + "Duplicate interface id " + + interface_id + + " in " + + iname + + " of " + + name + ) return {} ids.add(interface_id) - if 'ip' in interface: - for ip in interface['ip'].split(): + if "ip" in interface: + for ip in interface["ip"].split(): try: local_network = to_ip_network(to_unicode(ip), False) if first_local_network is None: @@ -2081,11 +2403,16 @@ def parse_pfsense_interfaces(self, pfsense, name): return {} if local_network.prefixlen == 32: - self._data.set_error("Invalid network prefix length for network " + ip + " in " + name) + self._data.set_error( + "Invalid network prefix length for network " + + ip + + " in " + + name + ) return {} # extracting & checking ip - group = re.match(r'([^\/]*)\/(\d+)', ip) + group = re.match(r"([^\/]*)\/(\d+)", ip) try: local_ip = to_ip_address(to_unicode(group.group(1))) if first_local_ip is None: @@ -2098,24 +2425,34 @@ def parse_pfsense_interfaces(self, pfsense, name): # extracting & checking remote networks remote_networks = set() - if 'remote_networks' in interface: - networks = self._data.unalias_ip(interface['remote_networks']) + if "remote_networks" in interface: + networks = self._data.unalias_ip(interface["remote_networks"]) for network in networks: try: remote_networks.add(to_ip_network(to_unicode(network))) except ValueError: - self._data.set_error("Invalid network " + network + " in remote_networks of " + name) + self._data.set_error( + "Invalid network " + + network + + " in remote_networks of " + + name + ) return {} # extracting & checking adjacent networks adjacent_networks = set() - if 'adjacent_networks' in interface: - networks = self._data.unalias_ip(interface['adjacent_networks']) + if "adjacent_networks" in interface: + networks = self._data.unalias_ip(interface["adjacent_networks"]) for network in networks: try: adjacent_networks.add(to_ip_network(to_unicode(network))) except ValueError: - self._data.set_error("Invalid network " + network + " in adjacent_networks of " + name) + self._data.set_error( + "Invalid network " + + network + + " in adjacent_networks of " + + name + ) return {} obj = PFSenseInterface() @@ -2124,7 +2461,7 @@ def parse_pfsense_interfaces(self, pfsense, name): obj.local_network = first_local_network obj.local_ips = local_ips obj.local_networks = local_networks - obj.bridge = (interface.get('bridge')) + obj.bridge = interface.get("bridge") obj.remote_networks = remote_networks obj.adjacent_networks = adjacent_networks obj.tags = tags @@ -2136,17 +2473,17 @@ def parse_pfsense_interfaces(self, pfsense, name): return ret def parse_pfsenses(self): - """ Checking all pfsenses definitions """ + """Checking all pfsenses definitions""" dups = {} ret = True for name, pfsense in self._data.pfsenses.items(): # interfaces field is mandatory - if 'interfaces' not in pfsense: + if "interfaces" not in pfsense: self._data.set_error("No interfaces field for pfsense " + name) ret = False continue - if not pfsense['interfaces']: + if not pfsense["interfaces"]: self._data.set_error("Empty interfaces field for pfsense " + name) ret = False continue @@ -2159,17 +2496,27 @@ def parse_pfsenses(self): # we check that all fields are valid for field in pfsense: - if field != 'interfaces': - self._data.set_error(field + " is not a valid field name in pfsense " + name) + if field != "interfaces": + self._data.set_error( + field + " is not a valid field name in pfsense " + name + ) ret = False # we check for duplicates _pfsense = deepcopy(pfsense) - if 'descr' in _pfsense: - del _pfsense['descr'] + if "descr" in _pfsense: + del _pfsense["descr"] dup = json.dumps(_pfsense) if dup in dups: - display.warning("duplicate pfsense definition for ip " + pfsense['ip'] + " (" + dups[dup] + ", " + name + ")") + display.warning( + "duplicate pfsense definition for ip " + + pfsense["ip"] + + " (" + + dups[dup] + + ", " + + name + + ")" + ) else: dups[dup] = name @@ -2179,14 +2526,14 @@ def parse_pfsenses(self): return ret def parse_hosts_aliases_objs(self): - """ Checking all host alias objs, addresses and finding pfsenses interfaces """ + """Checking all host alias objs, addresses and finding pfsenses interfaces""" for obj in self._data.hosts_aliases_obj.values(): obj.compute_all(self._data) return True def parse(self): - """ Check and parse everything """ + """Check and parse everything""" ret = True self.create_pfsenses_aliases() ret = ret and self.parse_hosts_aliases() @@ -2200,63 +2547,81 @@ def parse(self): class PFSenseRuleDecomposer(object): - """ Class decomposing rules into smaller rules (more suited to pfsense logic ) """ + """Class decomposing rules into smaller rules (more suited to pfsense logic )""" def __init__(self, data): self._data = data def host_separate(self, host): - """ separate aliases to remove mixed configuration + """separate aliases to remove mixed configuration where there is a local and remote network/ip is the host - host is expanded to sub-aliases if required """ + host is expanded to sub-aliases if required""" ret = [] if host.is_whole_not_in_pfsense(self._data.target): if self._data.debug is not None and self._data.debug == host.name: - display.warning('{0}: is_whole_not_in_pfsense {1}'.format(host.name, self._data.target.name)) + display.warning( + "{0}: is_whole_not_in_pfsense {1}".format( + host.name, self._data.target.name + ) + ) ret.append(host) elif host.is_whole_in_pfsense(self._data.target): if self._data.debug is not None and self._data.debug == host.name: - display.warning('{0}: is_whole_in_pfsense {1}'.format(host.name, self._data.target.name)) + display.warning( + "{0}: is_whole_in_pfsense {1}".format( + host.name, self._data.target.name + ) + ) ret.append(host) elif host.is_ip_broadcast(): if self._data.debug is not None and self._data.debug == host.name: - display.warning('{0}: is_ip_broadcast'.format(host.name)) + display.warning("{0}: is_ip_broadcast".format(host.name)) ret.append(host) else: alias = self._data.all_aliases[host.name] - if 'ip' in alias: - for alias_ip in alias['ip'].split(): + if "ip" in alias: + for alias_ip in alias["ip"].split(): ret_n = self.host_separate(self._data.hosts_aliases_obj[alias_ip]) if self._data.debug is not None and self._data.debug == host.name: - display.warning('{0}: host_separate: {1}'.format(host.name, ret_n)) + display.warning( + "{0}: host_separate: {1}".format(host.name, ret_n) + ) ret.extend(ret_n) return ret def host_separate_by_iface(self, host): - """ separate aliases to remove mixed configuration + """separate aliases to remove mixed configuration where there is a local and remote network/ip is the host - host is expanded to sub-aliases if required """ + host is expanded to sub-aliases if required""" ret = [] if host.is_whole_in_same_routing_ifaces(self._data.target): if self._data.debug is not None and self._data.debug == host.name: - display.warning('{0}: is_whole_in_same_routing_ifaces {1}'.format(host.name, self._data.target.name)) + display.warning( + "{0}: is_whole_in_same_routing_ifaces {1}".format( + host.name, self._data.target.name + ) + ) ret.append(host) else: alias = self._data.all_aliases[host.name] - if 'ip' in alias: - for alias_ip in alias['ip'].split(): - ret_n = self.host_separate_by_iface(self._data.hosts_aliases_obj[alias_ip]) + if "ip" in alias: + for alias_ip in alias["ip"].split(): + ret_n = self.host_separate_by_iface( + self._data.hosts_aliases_obj[alias_ip] + ) if self._data.debug is not None and self._data.debug == host.name: - display.warning('{0}: host_separate_by_iface: {1}'.format(host.name, ret_n)) + display.warning( + "{0}: host_separate_by_iface: {1}".format(host.name, ret_n) + ) ret.extend(ret_n) return ret def separate_aliases(self, rule, field, attr, func): - """ Separate aliases from field using func, setting new aliases in attr """ + """Separate aliases from field using func, setting new aliases in attr""" sub_rules = [] function = getattr(self, func) src_sep = function(field) @@ -2269,9 +2634,9 @@ def separate_aliases(self, rule, field, attr, func): return sub_rules def decompose_rule(self, rule): - """ Returns smaller rules from rule """ + """Returns smaller rules from rule""" # A PFSense rule can have only one src or dst - blocking = rule.action != 'pass' + blocking = rule.action != "pass" sub_rules = [] if len(rule.src) > 1 or len(rule.dst) > 1: @@ -2296,18 +2661,22 @@ def decompose_rule(self, rule): # if it's blocking or reject rule, we don't split the destination # since only we only need the source to know how to define the rule - sub_rules = self.separate_aliases(rule, src, 'src', 'host_separate') + sub_rules = self.separate_aliases(rule, src, "src", "host_separate") if not blocking and not sub_rules: - sub_rules = self.separate_aliases(rule, dst, 'dst', 'host_separate') + sub_rules = self.separate_aliases(rule, dst, "dst", "host_separate") if not sub_rules: - sub_rules = self.separate_aliases(rule, src, 'src', 'host_separate_by_iface') + sub_rules = self.separate_aliases( + rule, src, "src", "host_separate_by_iface" + ) if not blocking and not sub_rules: - sub_rules = self.separate_aliases(rule, dst, 'dst', 'host_separate_by_iface') + sub_rules = self.separate_aliases( + rule, dst, "dst", "host_separate_by_iface" + ) return sub_rules def decompose_rules(self): - """ Returns smaller rules (more suited to pfsense logic ) """ + """Returns smaller rules (more suited to pfsense logic )""" for rule in self._data.rules_obj.values(): todo = [] todo.append(rule) @@ -2321,14 +2690,14 @@ def decompose_rules(self): class PFSenseAliasFactory(object): - """ Class generating aliases definitions """ + """Class generating aliases definitions""" def __init__(self, data): self._data = data def add_host_alias_rec(self, alias, aliases): - """ set aliases hosts names to define (recursive) """ - if ':' in alias.name: + """set aliases hosts names to define (recursive)""" + if ":" in alias.name: return name = alias.name @@ -2340,17 +2709,17 @@ def add_host_alias_rec(self, alias, aliases): self.add_host_alias_rec(obj, aliases) def add_port_alias_rec(self, alias, aliases): - """ Return aliases ports names to define (recursive) """ + """Return aliases ports names to define (recursive)""" if alias in self._data.all_aliases: if alias not in aliases: aliases[alias] = self._data.all_aliases[alias] - if 'port' in aliases[alias]: - for port in aliases[alias]['port'].split(): + if "port" in aliases[alias]: + for port in aliases[alias]["port"].split(): self.add_port_alias_rec(port, aliases) def add_hosts_aliases(self, rule, aliases): - """ Return aliases hosts names to define """ + """Return aliases hosts names to define""" for rule_aliases in [rule.src, rule.dst, rule.src_nat, rule.dst_nat]: for alias in rule_aliases: if alias.fake: @@ -2358,7 +2727,7 @@ def add_hosts_aliases(self, rule, aliases): self.add_host_alias_rec(alias, aliases) def add_ports_aliases(self, rule, aliases): - """ Return aliases ports names to define """ + """Return aliases ports names to define""" for alias in rule.src_port: self.add_port_alias_rec(alias, aliases) @@ -2366,7 +2735,7 @@ def add_ports_aliases(self, rule, aliases): self.add_port_alias_rec(alias, aliases) def generate_aliases(self, rule_filter=None): - """ Return aliases definitions for pfsense_aggregate """ + """Return aliases definitions for pfsense_aggregate""" hosts_aliases = {} ports_aliases = {} @@ -2383,79 +2752,91 @@ def generate_aliases(self, rule_filter=None): ret = [] for name, alias in hosts_aliases.items(): definition = {} - definition['name'] = name - definition['type'] = alias['type'] - definition['address'] = ' '.join(alias['ip'].split()) - definition['state'] = 'present' - if 'descr' in alias: - definition['descr'] = alias['descr'] + definition["name"] = name + definition["type"] = alias["type"] + definition["address"] = " ".join(alias["ip"].split()) + definition["state"] = "present" + if "descr" in alias: + definition["descr"] = alias["descr"] else: - definition['descr'] = '' - definition['detail'] = '' + definition["descr"] = "" + definition["detail"] = "" ret.append(definition) for name, alias in ports_aliases.items(): definition = {} - definition['name'] = name - definition['type'] = 'port' - definition['address'] = ' '.join(alias['port'].replace('-', ':').split()) - definition['state'] = 'present' - if 'descr' in alias: - definition['descr'] = alias['descr'] + definition["name"] = name + definition["type"] = "port" + definition["address"] = " ".join(alias["port"].replace("-", ":").split()) + definition["state"] = "present" + if "descr" in alias: + definition["descr"] = alias["descr"] else: - definition['descr'] = '' - definition['detail'] = '' + definition["descr"] = "" + definition["detail"] = "" ret.append(definition) return ret @staticmethod def output_aliases(aliases, ignored_aliases): - """ Output aliases definitions for pfsense_aggregate """ + """Output aliases definitions for pfsense_aggregate""" print(" #===========================") print(" # Hosts & network aliases") print(" # ") definitions = list() for alias in aliases: - if alias['type'] == 'port': + if alias["type"] == "port": continue - definition = " - { name: \"" + alias['name'] + "\", type: \"" + alias['type'] + "\", address: \"" - definition += ' '.join(alias['address'].split()) + "\"" - if 'descr' in alias: - definition = definition + ", descr: \"" + alias['descr'] + "\"" - definition = definition + ", state: \"present\" }" + definition = ( + ' - { name: "' + + alias["name"] + + '", type: "' + + alias["type"] + + '", address: "' + ) + definition += " ".join(alias["address"].split()) + '"' + if "descr" in alias: + definition = definition + ', descr: "' + alias["descr"] + '"' + definition = definition + ', state: "present" }' definitions.append(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) print(" #===========================") print(" # ports aliases") print(" # ") definitions = list() for alias in aliases: - if alias['type'] != 'port': + if alias["type"] != "port": continue - definition = " - { name: \"" + alias['name'] + "\", type: \"port\", address: \"" + ' '.join(alias['address'].split()) + "\"" - if 'descr' in alias: - definition = definition + ", descr: \"" + alias['descr'] + "\"" - definition = definition + ", state: \"present\" }" + definition = ( + ' - { name: "' + + alias["name"] + + '", type: "port", address: "' + + " ".join(alias["address"].split()) + + '"' + ) + if "descr" in alias: + definition = definition + ', descr: "' + alias["descr"] + '"' + definition = definition + ', state: "present" }' definitions.append(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) print(" #===========================") print(" # ignored aliases") print(" # ") definitions = list() for alias in ignored_aliases: - definition = " - { name: \"" + alias + "\" }" + definition = ' - { name: "' + alias + '" }' definitions.append(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) class PFSenseRuleFactory(object): - """ Class generating rules definitions """ + """Class generating rules definitions""" def __init__(self, data, display_warnings=True): self._data = data @@ -2463,20 +2844,20 @@ def __init__(self, data, display_warnings=True): self._display_warnings = display_warnings def rule_interfaces_any(self, rule_obj): - """ Return interfaces set on which the rule is needed to be defined - Manage rules with any src or dst """ + """Return interfaces set on which the rule is needed to be defined + Manage rules with any src or dst""" src = rule_obj.src[0] dst = rule_obj.dst[0] # if rule is forced, we return the interface defined if rule_obj.force: - return set(rule_obj.get_option('ifilter').split()) + return set(rule_obj.get_option("ifilter").split()) - if src.name == 'any' and dst.name == 'any': + if src.name == "any" and dst.name == "any": # we return all interfaces of target return set(self._data.target.interfaces.keys()) - elif src.name == 'any': + elif src.name == "any": # if the destination is local, we return all interfaces of target if dst.is_whole_local(self._data.target): return set(self._data.target.interfaces.keys()) @@ -2484,26 +2865,34 @@ def rule_interfaces_any(self, rule_obj): # otherwise we return all interfaces of target if the destination is adjacent/remote # (we must be able to reach the destination to allow any src to access it) for iface, interface in self._data.target.interfaces.items(): - if dst.is_in_adjacent_networks(interface) or dst.is_in_remote_networks(interface): + if dst.is_in_adjacent_networks(interface) or dst.is_in_remote_networks( + interface + ): return set(self._data.target.interfaces.keys()) return set() - elif rule_obj.dst[0].name == 'any': + elif rule_obj.dst[0].name == "any": # we allow the interfaces matching the source ip/networks # or the adjacent/remote networks interfaces = set() for iface, interface in self._data.target.interfaces.items(): - if src.is_in_local_network(interface) or src.is_in_adjacent_networks(interface) or src.is_in_remote_networks(interface): + if ( + src.is_in_local_network(interface) + or src.is_in_adjacent_networks(interface) + or src.is_in_remote_networks(interface) + ): interfaces.add(iface) if self._data.debug is not None and self._data.debug == rule_obj.name: - display.warning('{0}: to_any_dst interfaces={1}'.format(rule_obj.name, interfaces)) + display.warning( + "{0}: to_any_dst interfaces={1}".format(rule_obj.name, interfaces) + ) return interfaces return None def rule_interfaces_ip_broadcast(self, rule_obj): - """ Return interfaces set on which the rule is needed to be defined - Manage rules with src or dst ip broadcast """ + """Return interfaces set on which the rule is needed to be defined + Manage rules with src or dst ip broadcast""" src = rule_obj.src[0] dst = rule_obj.dst[0] src_is_bcast = src.is_ip_broadcast() @@ -2512,10 +2901,16 @@ def rule_interfaces_ip_broadcast(self, rule_obj): return None if src_is_bcast and rule_obj.dst[0].is_whole_local(self._data.target): - return rule_obj.dst[0].local_interfaces[self._data.target.name] | rule_obj.dst[0].routed_interfaces[self._data.target.name] + return ( + rule_obj.dst[0].local_interfaces[self._data.target.name] + | rule_obj.dst[0].routed_interfaces[self._data.target.name] + ) if dst_is_bcast and rule_obj.src[0].is_whole_local(self._data.target): - return rule_obj.src[0].local_interfaces[self._data.target.name] | rule_obj.src[0].routed_interfaces[self._data.target.name] + return ( + rule_obj.src[0].local_interfaces[self._data.target.name] + | rule_obj.src[0].routed_interfaces[self._data.target.name] + ) # we return no rules for: # - broadcast to broadcast @@ -2524,8 +2919,8 @@ def rule_interfaces_ip_broadcast(self, rule_obj): return [] def bridged_by_interfaces(self, routing_interfaces, dst): - """ if all the routing_interfaces are bridged and the destinations are on local bridges too - return the destination bridges """ + """if all the routing_interfaces are bridged and the destinations are on local bridges too + return the destination bridges""" for iface in routing_interfaces: if not self._data.target.interfaces[iface].bridge: return None @@ -2540,23 +2935,24 @@ def bridged_by_interfaces(self, routing_interfaces, dst): return dst.local_interfaces[self._data.target.name] def rule_interfaces(self, rule_obj): - """ Return interfaces list on which the rule is needed to be defined """ + """Return interfaces list on which the rule is needed to be defined""" + def filter_interfaces(interfaces): if interface_filter is not None: return interface_filter & interfaces return interfaces # if the rule has a filter, apply it - rule_filter = rule_obj.get_option('filter') + rule_filter = rule_obj.get_option("filter") if rule_filter and self._data.target.name not in rule_filter.split(): return set() # if the rule has a efilter, apply it - rule_efilter = rule_obj.get_option('efilter') + rule_efilter = rule_obj.get_option("efilter") if rule_efilter and self._data.target.name in rule_efilter.split(): return set() - interface_filter = rule_obj.get_option('ifilter') + interface_filter = rule_obj.get_option("ifilter") if interface_filter is not None: interface_filter = set(interface_filter.split()) @@ -2564,7 +2960,11 @@ def filter_interfaces(interfaces): raise AssertionError() if self._data.debug is not None and self._data.debug == rule_obj.name: - display.warning('{0}: src={1} dst={2}'.format(rule_obj.name, rule_obj.src[0].name, rule_obj.dst[0].name)) + display.warning( + "{0}: src={1} dst={2}".format( + rule_obj.name, rule_obj.src[0].name, rule_obj.dst[0].name + ) + ) # if the rule uses 'any' interfaces = self.rule_interfaces_any(rule_obj) @@ -2581,14 +2981,20 @@ def filter_interfaces(interfaces): dst_is_local = rule_obj.dst[0].is_whole_local(self._data.target) if self._data.debug is not None and self._data.debug == rule_obj.name: - display.warning('{0}: src_is_local={1} dst_is_local={2}'.format(rule_obj.name, src_is_local, dst_is_local)) + display.warning( + "{0}: src_is_local={1} dst_is_local={2}".format( + rule_obj.name, src_is_local, dst_is_local + ) + ) # if it's a blocking or reject rule, we only use the src - if rule_obj.action != 'pass': + if rule_obj.action != "pass": if src_is_local: interfaces = rule_obj.src[0].local_interfaces[self._data.target.name] else: - interfaces = rule_obj.src[0].routed_by_interfaces(self._data.target, True) + interfaces = rule_obj.src[0].routed_by_interfaces( + self._data.target, True + ) return filter_interfaces(interfaces) @@ -2596,41 +3002,71 @@ def filter_interfaces(interfaces): if src_is_local and dst_is_local: if len(rule_obj.src[0].local_interfaces[self._data.target.name]) != 1: raise AssertionError( - 'Invalid local interfaces count for {0}: {1}' - .format(rule_obj.name, len(rule_obj.src[0].local_interfaces[self._data.target.name]))) + "Invalid local interfaces count for {0}: {1}".format( + rule_obj.name, + len(rule_obj.src[0].local_interfaces[self._data.target.name]), + ) + ) if len(rule_obj.dst[0].local_interfaces[self._data.target.name]) != 1: raise AssertionError( - 'Invalid local interfaces count for {0}: {1}' - .format(rule_obj.name, len(rule_obj.dst[0].local_interfaces[self._data.target.name]))) + "Invalid local interfaces count for {0}: {1}".format( + rule_obj.name, + len(rule_obj.dst[0].local_interfaces[self._data.target.name]), + ) + ) # if they are both on the same interface, we dont need any rule when: # - the interface is not a bridge # - src and dst are in the same network on the interface # - the pfsense is not the source/destination of the rule - src_interface = ''.join(rule_obj.src[0].local_interfaces[self._data.target.name]) - dst_interface = ''.join(rule_obj.dst[0].local_interfaces[self._data.target.name]) - if (src_interface == dst_interface and - not self._data.target.interfaces[src_interface].bridge and - self._data.target.interfaces[src_interface].are_in_same_network(rule_obj.src[0], rule_obj.dst[0])): - if not rule_obj.src[0].match_local_interface_ip(self._data.target) and not rule_obj.dst[0].match_local_interface_ip(self._data.target): + src_interface = "".join( + rule_obj.src[0].local_interfaces[self._data.target.name] + ) + dst_interface = "".join( + rule_obj.dst[0].local_interfaces[self._data.target.name] + ) + if ( + src_interface == dst_interface + and not self._data.target.interfaces[src_interface].bridge + and self._data.target.interfaces[src_interface].are_in_same_network( + rule_obj.src[0], rule_obj.dst[0] + ) + ): + if not rule_obj.src[0].match_local_interface_ip( + self._data.target + ) and not rule_obj.dst[0].match_local_interface_ip(self._data.target): return set() - return filter_interfaces(rule_obj.src[0].local_interfaces[self._data.target.name]) + return filter_interfaces( + rule_obj.src[0].local_interfaces[self._data.target.name] + ) # if the destination is unreachable - if not dst_is_local and src_is_local and not rule_obj.dst[0].is_adjacent_or_remote(self._data.target): + if ( + not dst_is_local + and src_is_local + and not rule_obj.dst[0].is_adjacent_or_remote(self._data.target) + ): if self._display_warnings: display.warning( - 'Destination {0} is not accessible from this pfSense for {1}.Please add the right adjacent/remote network if it\'s not an error' - .format(rule_obj.dst[0].name, rule_obj.name)) + "Destination {0} is not accessible from this pfSense for {1}.Please add the right adjacent/remote network if it's not an error".format( + rule_obj.dst[0].name, rule_obj.name + ) + ) return set() # if the source is unreachable - if not src_is_local and dst_is_local and not rule_obj.src[0].is_adjacent_or_remote(self._data.target): + if ( + not src_is_local + and dst_is_local + and not rule_obj.src[0].is_adjacent_or_remote(self._data.target) + ): if self._display_warnings: display.warning( - 'Source {0} can not access to this pfSense for {1}. Please add the right adjacent/remote network if it\'s not an error' - .format(rule_obj.src[0].name, rule_obj.name)) + "Source {0} can not access to this pfSense for {1}. Please add the right adjacent/remote network if it's not an error".format( + rule_obj.src[0].name, rule_obj.name + ) + ) return set() # we add all the interfaces the source can use to go out @@ -2640,11 +3076,19 @@ def filter_interfaces(interfaces): # we add interfaces the source can use to get in if not src_is_local: src_is_adjacent = rule_obj.src[0].is_adjacent(self._data.target) - routing_interfaces = rule_obj.src[0].routed_by_interfaces(self._data.target, not src_is_adjacent) + routing_interfaces = rule_obj.src[0].routed_by_interfaces( + self._data.target, not src_is_adjacent + ) if self._data.debug is not None and self._data.debug == rule_obj.name: - display.warning('{0}: src_is_adjacent={1}, routing_interfaces={2}, src={3}'.format( - rule_obj.name, src_is_adjacent, routing_interfaces, rule_obj.src[0].name)) + display.warning( + "{0}: src_is_adjacent={1}, routing_interfaces={2}, src={3}".format( + rule_obj.name, + src_is_adjacent, + routing_interfaces, + rule_obj.src[0].name, + ) + ) # if they are both not local and on the same interfaces or with an unreachable destination # we return nothing @@ -2653,132 +3097,159 @@ def filter_interfaces(interfaces): # if the source is remote and the destination is adjacent, we return the source interfaces if not src_is_adjacent and dst_is_adjacent: - if self._data.debug is not None and self._data.debug == rule_obj.name: + if ( + self._data.debug is not None + and self._data.debug == rule_obj.name + ): display.warning( - '{0}: dst_is_adjacent={1}, routing_interfaces={2}, dst={3}'.format( - rule_obj.name, dst_is_adjacent, filter_interfaces(routing_interfaces), rule_obj.dst[0].name)) + "{0}: dst_is_adjacent={1}, routing_interfaces={2}, dst={3}".format( + rule_obj.name, + dst_is_adjacent, + filter_interfaces(routing_interfaces), + rule_obj.dst[0].name, + ) + ) return filter_interfaces(routing_interfaces) # if the source is an adjacent networks, it can get out to reach remote networks - dst_routing_interfaces = rule_obj.dst[0].routed_by_interfaces(self._data.target, src_is_adjacent) + dst_routing_interfaces = rule_obj.dst[0].routed_by_interfaces( + self._data.target, src_is_adjacent + ) if self._data.debug is not None and self._data.debug == rule_obj.name: - display.warning('{0}: dst_is_adjacent={1}, dst_routing_interfaces={2}, dst={3}'.format( - rule_obj.name, dst_is_adjacent, dst_routing_interfaces, rule_obj.dst[0].name)) + display.warning( + "{0}: dst_is_adjacent={1}, dst_routing_interfaces={2}, dst={3}".format( + rule_obj.name, + dst_is_adjacent, + dst_routing_interfaces, + rule_obj.dst[0].name, + ) + ) # if the source is adjacent and the destination is remote, we return the source interfaces if src_is_adjacent and len(dst_routing_interfaces): return filter_interfaces(routing_interfaces) - routing_interfaces = routing_interfaces.difference(dst_routing_interfaces) + routing_interfaces = routing_interfaces.difference( + dst_routing_interfaces + ) if not routing_interfaces or not dst_routing_interfaces: return set() # if the interfaces we would use are bridged, and the destinations are on local bridges too # we declare the rule on the destination bridges since packets would come from there - bridge_interfaces = self.bridged_by_interfaces(routing_interfaces, rule_obj.dst[0]) + bridge_interfaces = self.bridged_by_interfaces( + routing_interfaces, rule_obj.dst[0] + ) if bridge_interfaces: interfaces.update(bridge_interfaces) else: interfaces.update(routing_interfaces) if not interfaces and (src_is_local or dst_is_local): - msg = 'Invalid sub-rule interfaces count ({0}), src={1}, dst={2}'.format(len(interfaces), rule_obj.src[0].name, rule_obj.dst[0].name) + msg = "Invalid sub-rule interfaces count ({0}), src={1}, dst={2}".format( + len(interfaces), rule_obj.src[0].name, rule_obj.dst[0].name + ) raise AssertionError(msg) return filter_interfaces(interfaces) def generate_rule(self, name, rule_obj, interfaces, last_name): - """ Generate rules definitions for rule """ + """Generate rules definitions for rule""" + def _gen_rule_dict(rule_def, name, interface=None): if interface is None: - interface = 'floating' + interface = "floating" floating = True else: floating = False definition = {} - definition['name'] = name - definition['action'] = rule_obj.action + definition["name"] = name + definition["action"] = rule_obj.action for field in OUTPUT_OPTION_FIELDS: value = rule_obj.get_option(field) if value is not None: definition[field] = value if floating: - definition['floating'] = 'yes' - if 'direction' not in definition: - definition['direction'] = 'any' - definition['interface'] = ','.join(rule_interfaces) + definition["floating"] = "yes" + if "direction" not in definition: + definition["direction"] = "any" + definition["interface"] = ",".join(rule_interfaces) else: - definition['interface'] = interface + definition["interface"] = interface - definition['state'] = 'present' + definition["state"] = "present" if interface in last_name and last_name[interface]: - definition['after'] = last_name[interface] + definition["after"] = last_name[interface] else: - definition['after'] = 'top' + definition["after"] = "top" if rule_obj.asymmetric: - definition['statetype'] = 'sloppy state' - definition['tcpflags_any'] = True + definition["statetype"] = "sloppy state" + definition["tcpflags_any"] = True definition.update(rule_def) interfaces[interface].append(definition) last_name[interface] = name - if rule_obj.invert_src and 'source' in definition: - definition['source'] = '!' + definition['source'] + if rule_obj.invert_src and "source" in definition: + definition["source"] = "!" + definition["source"] - if rule_obj.invert_dst and 'destination' in definition: - definition['destination'] = '!' + definition['destination'] + if rule_obj.invert_dst and "destination" in definition: + definition["destination"] = "!" + definition["destination"] if interface not in rule_obj.generated_names: rule_obj.generated_names[interface] = name def _gen_src_nat_rule_dict(rule_def, name, interface, src_nat): definition = {} - definition['descr'] = '{0}_{1}'.format(name, interface) - definition['interface'] = interface - definition['state'] = 'present' - definition['address'] = '{0}'.format(src_nat.name) + definition["descr"] = "{0}_{1}".format(name, interface) + definition["interface"] = interface + definition["state"] = "present" + definition["address"] = "{0}".format(src_nat.name) definition.update(rule_def) for field in OUTPUT_SRC_NAT_OPTION_FIELDS: value = rule_obj.get_option(field) if value is not None: definition[field] = value - for field in ['source', 'destination']: - key = field + '_port' + for field in ["source", "destination"]: + key = field + "_port" if key in definition: - definition[field] = '{0}:{1}'.format(definition[field], definition[key]) + definition[field] = "{0}:{1}".format( + definition[field], definition[key] + ) del definition[key] interfaces[interface].append(definition) def _gen_dst_nat_rule_dict(rule_def, name, interface, dst_nat, dst_nat_port): definition = {} - definition['descr'] = '{0}_{1}'.format(name, interface) - definition['interface'] = interface - definition['state'] = 'present' - definition['target'] = '{0}:{1}'.format(dst_nat.name, dst_nat_port) + definition["descr"] = "{0}_{1}".format(name, interface) + definition["interface"] = interface + definition["state"] = "present" + definition["target"] = "{0}:{1}".format(dst_nat.name, dst_nat_port) definition.update(rule_def) for field in OUTPUT_DST_NAT_OPTION_FIELDS: value = rule_obj.get_option(field) if value is not None: definition[field] = value - if 'associated_rule' not in definition: - definition['associated_rule'] = 'pass' + if "associated_rule" not in definition: + definition["associated_rule"] = "pass" - for field in ['source', 'destination']: - key = field + '_port' + for field in ["source", "destination"]: + key = field + "_port" if key in definition: - definition[field] = '{0}:{1}'.format(definition[field], definition[key]) + definition[field] = "{0}:{1}".format( + definition[field], definition[key] + ) del definition[key] - if rule_obj.invert_src_nat and 'source' in definition: - definition['source'] = '!' + definition['source'] + if rule_obj.invert_src_nat and "source" in definition: + definition["source"] = "!" + definition["source"] - if rule_obj.invert_dst_nat and 'destination' in definition: - definition['destination'] = '!' + definition['destination'] + if rule_obj.invert_dst_nat and "destination" in definition: + definition["destination"] = "!" + definition["destination"] interfaces[interface].append(definition) @@ -2789,37 +3260,41 @@ def _gen_dst_nat_rule_dict(rule_def, name, interface, dst_nat, dst_nat_port): raise AssertionError() rule = {} - rule['src'] = rule_obj.src[0].name - rule['dst'] = rule_obj.dst[0].name + rule["src"] = rule_obj.src[0].name + rule["dst"] = rule_obj.dst[0].name if rule_obj.protocol: - rule['protocol'] = ' '.join(rule_obj.protocol) + rule["protocol"] = " ".join(rule_obj.protocol) if self._data.aggregate: if rule_obj.src_port: if len(rule_obj.src_port) == 1: - rule['src_port'] = ' '.join(rule_obj.src_port) + rule["src_port"] = " ".join(rule_obj.src_port) else: - rule['src_port'] = self._data.get_ports_alias(set(rule_obj.src_port), name) - rule_obj.src_port = [rule['src_port']] + rule["src_port"] = self._data.get_ports_alias( + set(rule_obj.src_port), name + ) + rule_obj.src_port = [rule["src_port"]] if rule_obj.dst_port: if len(rule_obj.dst_port) == 1: - rule['dst_port'] = ' '.join(rule_obj.dst_port) + rule["dst_port"] = " ".join(rule_obj.dst_port) else: - rule['dst_port'] = self._data.get_ports_alias(set(rule_obj.dst_port), name) - rule_obj.dst_port = [rule['dst_port']] + rule["dst_port"] = self._data.get_ports_alias( + set(rule_obj.dst_port), name + ) + rule_obj.dst_port = [rule["dst_port"]] else: if rule_obj.src_port: - rule['src_port'] = ' '.join(rule_obj.src_port) + rule["src_port"] = " ".join(rule_obj.src_port) if rule_obj.dst_port: - rule['dst_port'] = ' '.join(rule_obj.dst_port) + rule["dst_port"] = " ".join(rule_obj.dst_port) - base = rule_product_dict(base, rule, 'src', 'source') - base = rule_product_dict(base, rule, 'dst', 'destination') - base = rule_product_dict(base, rule, 'protocol') - base = rule_product_dict(base, rule, 'src_port', 'source_port') - base = rule_product_dict(base, rule, 'dst_port', 'destination_port') + base = rule_product_dict(base, rule, "src", "source") + base = rule_product_dict(base, rule, "dst", "destination") + base = rule_product_dict(base, rule, "protocol") + base = rule_product_dict(base, rule, "src_port", "source_port") + base = rule_product_dict(base, rule, "dst_port", "destination_port") if rule_obj.floating: rule_interfaces = list(rule_obj.interfaces) @@ -2836,10 +3311,18 @@ def _gen_dst_nat_rule_dict(rule_def, name, interface, dst_nat, dst_nat_port): for interface in rule_obj.interfaces: if len(base) == 1: if rule_obj.src_nat: - _gen_src_nat_rule_dict(base[0], name, interface, rule_obj.src_nat[0]) + _gen_src_nat_rule_dict( + base[0], name, interface, rule_obj.src_nat[0] + ) if rule_obj.dst_nat: - _gen_dst_nat_rule_dict(base[0], name, interface, rule_obj.dst_nat[0], rule_obj.dst_nat_port[0]) + _gen_dst_nat_rule_dict( + base[0], + name, + interface, + rule_obj.dst_nat[0], + rule_obj.dst_nat_port[0], + ) if not rule_obj.src_nat and not rule_obj.dst_nat: _gen_rule_dict(base[0], name, interface) @@ -2848,17 +3331,26 @@ def _gen_dst_nat_rule_dict(rule_def, name, interface, dst_nat, dst_nat_port): for rule_def in base: rule_name = name + "_" + str(rule_idx) if rule_obj.src_nat: - _gen_src_nat_rule_dict(rule_def, rule_name, interface, rule_obj.src_nat[0]) + _gen_src_nat_rule_dict( + rule_def, rule_name, interface, rule_obj.src_nat[0] + ) if rule_obj.dst_nat: - _gen_dst_nat_rule_dict(rule_def, rule_name, interface, rule_obj.dst_nat[0], rule_obj.dst_nat_port[0]) + _gen_dst_nat_rule_dict( + rule_def, + rule_name, + interface, + rule_obj.dst_nat[0], + rule_obj.dst_nat_port[0], + ) if not rule_obj.src_nat and not rule_obj.dst_nat: _gen_rule_dict(rule_def, rule_name, interface) rule_idx = rule_idx + 1 def aggregate_subrules(self, rule, interfaces, subrules, sub_interfaces): - """ aggregate generated subrules """ + """aggregate generated subrules""" + def _get_same_rule(new_rules, src, dst): for rule in new_rules: if rule.src[0].name == src.name and rule.dst[0].name == dst.name: @@ -2875,28 +3367,52 @@ def _aggregate_job(interface=None): src = subrule.src[0] dst = subrule.dst[0] if len(src_group_name) != 1: - src = self._data.get_hosts_alias(src_group_name, src_group_ip, src_group_net, rule.name) - elif (not rule.src_nat and not rule.floating and len(subrule.src[0].networks) == 1 and - len(subrule.src[0].ips) == 0 and - subrule.src[0].networks[0] == self._data.target.interfaces[interface].local_network): + src = self._data.get_hosts_alias( + src_group_name, src_group_ip, src_group_net, rule.name + ) + elif ( + not rule.src_nat + and not rule.floating + and len(subrule.src[0].networks) == 1 + and len(subrule.src[0].ips) == 0 + and subrule.src[0].networks[0] + == self._data.target.interfaces[interface].local_network + ): src = subrule.src[0].copy() src.name = "NET:{0}".format(interface) - elif (not rule.src_nat and not rule.floating and len(subrule.src[0].networks) == 0 and - len(subrule.src[0].ips) == 1 and - subrule.src[0].ips[0] == self._data.target.interfaces[interface].local_ip): + elif ( + not rule.src_nat + and not rule.floating + and len(subrule.src[0].networks) == 0 + and len(subrule.src[0].ips) == 1 + and subrule.src[0].ips[0] + == self._data.target.interfaces[interface].local_ip + ): src = subrule.src[0].copy() src.name = "IP:{0}".format(interface) if len(dst_group_name) != 1: - dst = self._data.get_hosts_alias(dst_group_name, dst_group_ip, dst_group_net, rule.name) - elif (not rule.src_nat and not rule.floating and len(subrule.dst[0].networks) == 1 and - len(subrule.dst[0].ips) == 0 and - subrule.dst[0].networks[0] == self._data.target.interfaces[interface].local_network): + dst = self._data.get_hosts_alias( + dst_group_name, dst_group_ip, dst_group_net, rule.name + ) + elif ( + not rule.src_nat + and not rule.floating + and len(subrule.dst[0].networks) == 1 + and len(subrule.dst[0].ips) == 0 + and subrule.dst[0].networks[0] + == self._data.target.interfaces[interface].local_network + ): dst = subrule.dst[0].copy() dst.name = "NET:{0}".format(interface) - elif (not rule.src_nat and not rule.floating and len(subrule.dst[0].networks) == 0 and - len(subrule.dst[0].ips) == 1 and - subrule.dst[0].ips[0] == self._data.target.interfaces[interface].local_ip): + elif ( + not rule.src_nat + and not rule.floating + and len(subrule.dst[0].networks) == 0 + and len(subrule.dst[0].ips) == 1 + and subrule.dst[0].ips[0] + == self._data.target.interfaces[interface].local_ip + ): dst = subrule.dst[0].copy() dst.name = "IP:{0}".format(interface) @@ -2956,13 +3472,13 @@ def _aggregate_job(interface=None): _add_list(dst_group_net, subrule.dst[0].networks) _aggregate_job(interface) - if rule.floating and 'floating' not in interfaces: - interfaces['floating'] = [] + if rule.floating and "floating" not in interfaces: + interfaces["floating"] = [] subrules.extend(new_rules) def guess_rules(self, rule_filter): - """ Return interfaces, rules and rules names """ + """Return interfaces, rules and rules names""" interfaces = {} rules = list() @@ -2971,7 +3487,9 @@ def guess_rules(self, rule_filter): sub_interfaces = dict() # for each subrule, we guess on which interfaces the subrule needs to be generated, if any - for subrule in sorted(rule.sub_rules, key=lambda x: x.src[0].name + x.dst[0].name): + for subrule in sorted( + rule.sub_rules, key=lambda x: x.src[0].name + x.dst[0].name + ): subrule.interfaces = self.rule_interfaces(subrule) if rule_filter is not None and name != rule_filter: continue @@ -3007,8 +3525,8 @@ def guess_rules(self, rule_filter): return (interfaces, rules) def generate_rules(self, rule_filter=None): - """ Return rules definitions for pfsense_aggregate - if rule_filter, process only rules matching rule_filter + """Return rules definitions for pfsense_aggregate + if rule_filter, process only rules matching rule_filter """ filter_rules = [] @@ -3027,22 +3545,22 @@ def generate_rules(self, rule_filter=None): for rule in rules: self.generate_rule(rule.name, rule, interfaces, last_name) else: - for (name, rule) in rules: + for name, rule in rules: self.generate_rule(name, rule, interfaces, last_name) # since nat is not separated by interface, we manage the order here - last_src_nat = 'top' - last_dst_nat = 'top' + last_src_nat = "top" + last_dst_nat = "top" for name in sorted(interfaces.keys()): interface = interfaces[name] for rule in interface: - if 'address' in rule: - rule['after'] = last_src_nat - last_src_nat = rule['descr'] + if "address" in rule: + rule["after"] = last_src_nat + last_src_nat = rule["descr"] src_nat_rules.append(rule) - elif 'target' in rule: - rule['after'] = last_dst_nat - last_dst_nat = rule['descr'] + elif "target" in rule: + rule["after"] = last_dst_nat + last_dst_nat = rule["descr"] dst_nat_rules.append(rule) else: filter_rules.append(rule) @@ -3050,155 +3568,181 @@ def generate_rules(self, rule_filter=None): return (filter_rules, src_nat_rules, dst_nat_rules) def output_rules(self, rules, ignored_rules): - """ Output rules definitions for pfsense_aggregate """ + """Output rules definitions for pfsense_aggregate""" print(" #===========================") print(" # Rules") print(" # ") interfaces = list(self._data.target.interfaces.keys()) - interfaces.append('floating') + interfaces.append("floating") definitions = list() for interface in interfaces: for rule in rules: - if interface == rule['interface'] or interface == 'floating' and rule.get('floating'): - definition = ' - { name: "%s", source: "%s", ' % (rule['name'], rule['source']) - if 'source_port' in rule: - definition += 'source_port: "{0}", '.format(rule['source_port']) - - definition += 'destination: "{0}", '.format(rule['destination']) - if 'destination_port' in rule: - definition += 'destination_port: "{0}", '.format(rule['destination_port']) - - definition += 'interface: "{0}", action: "{1}"'.format(rule['interface'], rule['action']) - - if rule.get('protocol'): - definition += ", protocol: \"" + rule['protocol'] + "\"" - if rule.get('descr'): - definition += ", descr: \"" + rule['descr'] + "\"" + if ( + interface == rule["interface"] + or interface == "floating" + and rule.get("floating") + ): + definition = ' - { name: "%s", source: "%s", ' % ( + rule["name"], + rule["source"], + ) + if "source_port" in rule: + definition += 'source_port: "{0}", '.format(rule["source_port"]) + + definition += 'destination: "{0}", '.format(rule["destination"]) + if "destination_port" in rule: + definition += 'destination_port: "{0}", '.format( + rule["destination_port"] + ) + + definition += 'interface: "{0}", action: "{1}"'.format( + rule["interface"], rule["action"] + ) + + if rule.get("protocol"): + definition += ', protocol: "' + rule["protocol"] + '"' + if rule.get("descr"): + definition += ', descr: "' + rule["descr"] + '"' for field in OUTPUT_OPTION_FIELDS: value = rule.get(field) if value is not None: - definition += ', {0}: {1}'.format(field, value) + definition += ", {0}: {1}".format(field, value) - if rule.get('floating'): + if rule.get("floating"): definition += ", floating: True" - if rule.get('statetype') is not None: - definition += ", statetype: '{0}'".format(rule.get('statetype')) + if rule.get("statetype") is not None: + definition += ", statetype: '{0}'".format(rule.get("statetype")) - if rule.get('tcpflags_any'): + if rule.get("tcpflags_any"): definition += ", tcpflags_any: True" - if rule.get('after') and not self._data.gendiff: - definition += ", after: \"" + rule['after'] + "\"" - definition += ", state: \"present\" }" + if rule.get("after") and not self._data.gendiff: + definition += ', after: "' + rule["after"] + '"' + definition += ', state: "present" }' if self._data.gendiff: definitions.append(definition) else: print(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) print(" #===========================") print(" # ignored rules") print(" # ") definitions = list() for rule in ignored_rules: - definition = " - { name: \"" + rule + "\" }" + definition = ' - { name: "' + rule + '" }' definitions.append(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) def output_src_nat_rules(self, rules): - """ Output outbound definitions for pfsense_aggregate """ + """Output outbound definitions for pfsense_aggregate""" print(" #===========================") print(" # Nat outbound rules") print(" # ") interfaces = list(self._data.target.interfaces.keys()) - interfaces.append('floating') + interfaces.append("floating") definitions = list() for interface in sorted(interfaces): for rule in rules: - if interface == rule['interface']: - definition = ' - { descr: "%s", source: "%s", ' % (rule['descr'], rule['source']) - if 'source_port' in rule: - definition += 'source_port: "{0}", '.format(rule['source_port']) - - definition += 'destination: "{0}", '.format(rule['destination']) - if 'destination_port' in rule: - definition += 'destination_port: "{0}", '.format(rule['destination_port']) - - definition += 'interface: "{0}", address: "{1}"'.format(rule['interface'], rule['address']) + if interface == rule["interface"]: + definition = ' - { descr: "%s", source: "%s", ' % ( + rule["descr"], + rule["source"], + ) + if "source_port" in rule: + definition += 'source_port: "{0}", '.format(rule["source_port"]) + + definition += 'destination: "{0}", '.format(rule["destination"]) + if "destination_port" in rule: + definition += 'destination_port: "{0}", '.format( + rule["destination_port"] + ) + + definition += 'interface: "{0}", address: "{1}"'.format( + rule["interface"], rule["address"] + ) for field in OUTPUT_SRC_NAT_OPTION_FIELDS: value = rule.get(field) if value is not None: - definition += ', {0}: {1}'.format(field, value) + definition += ", {0}: {1}".format(field, value) - if rule.get('protocol'): - definition += ", protocol: \"" + rule['protocol'] + "\"" - if rule.get('descr'): - definition += ", descr: \"" + rule['descr'] + "\"" + if rule.get("protocol"): + definition += ', protocol: "' + rule["protocol"] + '"' + if rule.get("descr"): + definition += ', descr: "' + rule["descr"] + '"' - if rule.get('after') and not self._data.gendiff: - definition += ", after: \"" + rule['after'] + "\"" - definition += ", state: \"present\" }" + if rule.get("after") and not self._data.gendiff: + definition += ', after: "' + rule["after"] + '"' + definition += ', state: "present" }' if self._data.gendiff: definitions.append(definition) else: print(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) def output_dst_nat_rules(self, rules): - """ Output outbound definitions for pfsense_aggregate """ + """Output outbound definitions for pfsense_aggregate""" print(" #===========================") print(" # Nat port forward rules") print(" # ") interfaces = list(self._data.target.interfaces.keys()) - interfaces.append('floating') + interfaces.append("floating") definitions = list() for interface in sorted(interfaces): for rule in rules: - if interface == rule['interface']: - definition = ' - { descr: "%s", source: "%s", ' % (rule['descr'], rule['source']) - definition += 'destination: "{0}", '.format(rule['destination']) - definition += 'interface: "{0}", target: "{1}"'.format(rule['interface'], rule['target']) + if interface == rule["interface"]: + definition = ' - { descr: "%s", source: "%s", ' % ( + rule["descr"], + rule["source"], + ) + definition += 'destination: "{0}", '.format(rule["destination"]) + definition += 'interface: "{0}", target: "{1}"'.format( + rule["interface"], rule["target"] + ) for field in OUTPUT_SRC_NAT_OPTION_FIELDS: value = rule.get(field) if value is not None: - definition += ', {0}: {1}'.format(field, value) + definition += ", {0}: {1}".format(field, value) - if rule.get('descr'): - definition += ", descr: \"" + rule['descr'] + "\"" + if rule.get("descr"): + definition += ', descr: "' + rule["descr"] + '"' - if rule.get('after') and not self._data.gendiff: - definition += ", after: \"" + rule['after'] + "\"" - definition += ", state: \"present\" }" + if rule.get("after") and not self._data.gendiff: + definition += ', after: "' + rule["after"] + '"' + definition += ', state: "present" }' if self._data.gendiff: definitions.append(definition) else: print(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) class PFSenseRuleSeparatorFactory(object): - """ Class generating rule separators definitions """ + """Class generating rule separators definitions""" def __init__(self, data): self._data = data def _find_first_separator_rule(self, separator): - """ return the name of the first rule in the separator """ + """return the name of the first rule in the separator""" for rule in self._data.rules_obj.values(): for subrule in rule.sub_rules: - if subrule.separator.name == separator.name and separator.interface in subrule.generated_names: + if ( + subrule.separator.name == separator.name + and separator.interface in subrule.generated_names + ): return subrule.generated_names[separator.interface] return None def generate_rule_separators(self, rule_filter=None): - """ Return rule_separators definitions for pfsense_aggregate """ + """Return rule_separators definitions for pfsense_aggregate""" separators = OrderedDict() @@ -3212,7 +3756,7 @@ def generate_rule_separators(self, rule_filter=None): separator = PFSenseRuleSeparator() separator.name = subrule.separator.name if rule.floating: - separator.interface = 'floating' + separator.interface = "floating" else: separator.interface = interface if separator not in separators: @@ -3221,78 +3765,89 @@ def generate_rule_separators(self, rule_filter=None): ret = [] for separator in separators.values(): definition = {} - definition['name'] = separator.name - if separator.interface == 'floating': - definition['floating'] = True + definition["name"] = separator.name + if separator.interface == "floating": + definition["floating"] = True else: - definition['interface'] = separator.interface - definition['before'] = self._find_first_separator_rule(separator) - if definition['before'] is None: + definition["interface"] = separator.interface + definition["before"] = self._find_first_separator_rule(separator) + if definition["before"] is None: # for now we don't manage empty separators continue - definition['state'] = 'present' + definition["state"] = "present" ret.append(definition) return ret def output_rule_separators(self, separators): - """ Output rule separators definitions for pfsense_aggregate """ + """Output rule separators definitions for pfsense_aggregate""" print(" #===========================") print(" # Rule separators") print(" # ") interfaces = list(self._data.target.interfaces.keys()) - interfaces.append('floating') + interfaces.append("floating") definitions = list() for interface in interfaces: for separator in separators: - if 'interface' in separator and interface == separator['interface'] or interface == 'floating' and 'floating' in separator: - definition = " - { name: \"" + separator['name'] + "\", " - if interface == 'floating': + if ( + "interface" in separator + and interface == separator["interface"] + or interface == "floating" + and "floating" in separator + ): + definition = ' - { name: "' + separator["name"] + '", ' + if interface == "floating": definition += "floating: True, " else: - definition += "interface: \"" + separator['interface'] + "\", " - definition += "before: \"" + separator['before'] + "\", state: \"present\" }" + definition += 'interface: "' + separator["interface"] + '", ' + definition += ( + 'before: "' + separator["before"] + '", state: "present" }' + ) definitions.append(definition) definitions.sort() - print('\n'.join(definitions)) + print("\n".join(definitions)) class LookupModule(LookupBase): - """ Lookup module generating pfsense definitions """ + """Lookup module generating pfsense definitions""" def get_hostname(self): - """ Just for easier mock """ - myvars = getattr(self._templar, '_available_variables', {}) - return myvars['inventory_hostname'] + """Just for easier mock""" + myvars = getattr(self._templar, "_available_variables", {}) + return myvars["inventory_hostname"] @staticmethod def get_definitions(from_file): - """ Just for easier mock """ + """Just for easier mock""" return ordered_load(open(from_file), yaml.SafeLoader) def load_data(self, from_file): - """ Load and return pfsense data """ + """Load and return pfsense data""" fvars = self.get_definitions(from_file) if fvars is None: raise AnsibleError("No usable data found in {0}".format(from_file)) - for section in ['hosts_aliases', 'ports_aliases', 'pfsenses', 'rules']: + for section in ["hosts_aliases", "ports_aliases", "pfsenses", "rules"]: if section not in fvars: - raise AnsibleError("Missing {0} section in {1}".format(section, from_file)) + raise AnsibleError( + "Missing {0} section in {1}".format(section, from_file) + ) data = PFSenseData( - hosts_aliases=fvars['hosts_aliases'], - ports_aliases=fvars['ports_aliases'], - pfsenses=fvars['pfsenses'], - rules=fvars['rules'], - target_name=self.get_hostname() + hosts_aliases=fvars["hosts_aliases"], + ports_aliases=fvars["ports_aliases"], + pfsenses=fvars["pfsenses"], + rules=fvars["rules"], + target_name=self.get_hostname(), ) return data def _run(self, terms, variables, **kwargs): - """ Main function """ + """Main function""" if len(terms) != 2: - raise AnsibleError("pfsensible.core.pfsense lookup requires a filename and another parameter in [aliases, rules, rule_separators, all_definitions]") + raise AnsibleError( + "pfsensible.core.pfsense lookup requires a filename and another parameter in [aliases, rules, rule_separators, all_definitions]" + ) data = self.load_data(terms[0]) @@ -3301,38 +3856,38 @@ def _run(self, terms, variables, **kwargs): raise AnsibleError("Error checking pfsense data") alias_factory = PFSenseAliasFactory(data) - rule_factory = PFSenseRuleFactory(data, display_warnings=(terms[1] == 'rules')) + rule_factory = PFSenseRuleFactory(data, display_warnings=(terms[1] == "rules")) rule_separator_factory = PFSenseRuleSeparatorFactory(data) (rules, src_nat_rules, dst_nat_rules) = rule_factory.generate_rules() rule_separators = rule_separator_factory.generate_rule_separators() aliases = alias_factory.generate_aliases() - if terms[1] == 'aliases': + if terms[1] == "aliases": return [aliases] - elif terms[1] == 'rules': + elif terms[1] == "rules": return [rules] - elif terms[1] == 'nat_outbounds': + elif terms[1] == "nat_outbounds": return [src_nat_rules] - elif terms[1] == 'nat_port_forwards': + elif terms[1] == "nat_port_forwards": return [dst_nat_rules] - elif terms[1] == 'rule_separators': + elif terms[1] == "rule_separators": return [rule_separators] - elif terms[1] == 'all_definitions': + elif terms[1] == "all_definitions": res = {} - res['aggregated_aliases'] = aliases - res['aggregated_rules'] = rules - res['aggregated_rule_separators'] = rule_separators - res['aggregated_nat_outbounds'] = src_nat_rules - res['aggregated_nat_port_forwards'] = dst_nat_rules - res['ignored_rules'] = list(data.ignored_rules) - res['ignored_aliases'] = list(data.ignored_aliases) + res["aggregated_aliases"] = aliases + res["aggregated_rules"] = rules + res["aggregated_rule_separators"] = rule_separators + res["aggregated_nat_outbounds"] = src_nat_rules + res["aggregated_nat_port_forwards"] = dst_nat_rules + res["ignored_rules"] = list(data.ignored_rules) + res["ignored_aliases"] = list(data.ignored_aliases) return [res] return [] def run(self, terms, variables, **kwargs): - """ Entry point for main function (to properly catch & display exceptions stacktrace)""" + """Entry point for main function (to properly catch & display exceptions stacktrace)""" trace = None res = [] @@ -3352,15 +3907,15 @@ def run(self, terms, variables, **kwargs): def unit_test_helper(filename, pfname): - """ Unit test helper """ + """Unit test helper""" rule_filter = None fvars = ordered_load(open(filename), yaml.SafeLoader) data = PFSenseData( - hosts_aliases=fvars['hosts_aliases'], - ports_aliases=fvars['ports_aliases'], - pfsenses=fvars['pfsenses'], - rules=fvars['rules'], + hosts_aliases=fvars["hosts_aliases"], + ports_aliases=fvars["ports_aliases"], + pfsenses=fvars["pfsenses"], + rules=fvars["rules"], target_name=pfname, ) @@ -3386,13 +3941,23 @@ def unit_test_helper(filename, pfname): def main(): - """ Output debug helper """ + """Output debug helper""" parser = argparse.ArgumentParser() parser.add_argument("file", help="input file") parser.add_argument("pfsense", help="target_fw") - parser.add_argument('filter', help="rule_name", nargs='?') - parser.add_argument("-a", "--dont-aggregate", action="store_false", help="dont generate aliases to aggregate rules") - parser.add_argument("-g", "--gendiff", action="store_true", help="output more suitable for diffs (debbuging)") + parser.add_argument("filter", help="rule_name", nargs="?") + parser.add_argument( + "-a", + "--dont-aggregate", + action="store_false", + help="dont generate aliases to aggregate rules", + ) + parser.add_argument( + "-g", + "--gendiff", + action="store_true", + help="output more suitable for diffs (debbuging)", + ) parser.add_argument("-d", "--debug-rule", action="store", help="debug rule") args = parser.parse_args() @@ -3400,14 +3965,14 @@ def main(): if args.filter: rule_filter = args.filter - print('Loading data...') + print("Loading data...") fvars = ordered_load(open(args.file), yaml.SafeLoader) data = PFSenseData( - hosts_aliases=fvars['hosts_aliases'], - ports_aliases=fvars['ports_aliases'], - pfsenses=fvars['pfsenses'], - rules=fvars['rules'], + hosts_aliases=fvars["hosts_aliases"], + ports_aliases=fvars["ports_aliases"], + pfsenses=fvars["pfsenses"], + rules=fvars["rules"], target_name=args.pfsense, gendiff=args.gendiff, debug=args.debug_rule, @@ -3415,7 +3980,7 @@ def main(): ) parser = PFSenseDataParser(data) - print('Parsing data...') + print("Parsing data...") if not parser.parse(): return @@ -3423,16 +3988,16 @@ def main(): rule_factory = PFSenseRuleFactory(data) rule_separator_factory = PFSenseRuleSeparatorFactory(data) - print('Generating rules...') + print("Generating rules...") (rules, src_nat_rules, dst_nat_rules) = rule_factory.generate_rules(rule_filter) if rule_filter is None: - print('Generating rule separators...') + print("Generating rule separators...") rule_separators = rule_separator_factory.generate_rule_separators(rule_filter) else: - print('Filter set. Skipping rule separators...') + print("Filter set. Skipping rule separators...") - print('Generating aliases...') + print("Generating aliases...") aliases = alias_factory.generate_aliases(rule_filter) alias_factory.output_aliases(aliases, data.ignored_aliases) @@ -3443,16 +4008,17 @@ def main(): rule_separator_factory.output_rule_separators(rule_separators) -if __name__ == '__main__': +if __name__ == "__main__": profile = False if profile: import cProfile import pstats + profiler = cProfile.Profile() profiler.enable() main() profiler.disable() - stats = pstats.Stats(profiler).sort_stats('tottime') + stats = pstats.Stats(profiler).sort_stats("tottime") stats.print_stats() else: main() diff --git a/plugins/module_utils/__impl/addresses.py b/plugins/module_utils/__impl/addresses.py index 121357bf..aa614a21 100644 --- a/plugins/module_utils/__impl/addresses.py +++ b/plugins/module_utils/__impl/addresses.py @@ -5,22 +5,34 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type try: - from ipaddress import ip_address, ip_network, IPv4Address, IPv6Address, IPv4Network, IPv6Network + from ipaddress import ( + ip_address, + ip_network, + IPv4Address, + IPv6Address, + IPv4Network, + IPv6Network, + ) except ImportError: from ansible_collections.community.general.plugins.module_utils.compat.ipaddress import ( - ip_address, IPv4Address, IPv6Address, - ip_network, IPv4Network, IPv6Network + ip_address, + IPv4Address, + IPv6Address, + ip_network, + IPv4Network, + IPv6Network, ) import re @staticmethod def is_ipv4_address(address): - """ test if address is a valid ipv4 address """ + """test if address is a valid ipv4 address""" try: - addr = ip_address(u'{0}'.format(address)) + addr = ip_address("{0}".format(address)) return isinstance(addr, IPv4Address) except ValueError: pass @@ -29,9 +41,9 @@ def is_ipv4_address(address): @staticmethod def is_ipv6_address(address): - """ test if address is a valid ipv6 address """ + """test if address is a valid ipv6 address""" try: - addr = ip_address(u'{0}'.format(address)) + addr = ip_address("{0}".format(address)) return isinstance(addr, IPv6Address) except ValueError: pass @@ -40,9 +52,9 @@ def is_ipv6_address(address): @staticmethod def is_ipv4_network(address, strict=True): - """ test if address is a valid ipv4 network """ + """test if address is a valid ipv4 network""" try: - addr = ip_network(u'{0}'.format(address), strict=strict) + addr = ip_network("{0}".format(address), strict=strict) return isinstance(addr, IPv4Network) except ValueError: pass @@ -51,9 +63,9 @@ def is_ipv4_network(address, strict=True): @staticmethod def is_ipv6_network(address, strict=True): - """ test if address is a valid ipv6 network """ + """test if address is a valid ipv6 network""" try: - addr = ip_network(u'{0}'.format(address), strict=strict) + addr = ip_network("{0}".format(address), strict=strict) return isinstance(addr, IPv6Network) except ValueError: pass @@ -61,21 +73,23 @@ def is_ipv6_network(address, strict=True): def is_ip_network(self, address, strict=True): - """ test if address is a valid ip network """ - return self.is_ipv4_network(address, strict) or self.is_ipv6_network(address, strict) + """test if address is a valid ip network""" + return self.is_ipv4_network(address, strict) or self.is_ipv6_network( + address, strict + ) def is_within_local_networks(self, address): - """ test if address is contained in our local networks """ + """test if address is contained in our local networks""" networks = self.get_interfaces_networks() try: - addr = ip_address(u'{0}'.format(address)) + addr = ip_address("{0}".format(address)) except ValueError: return False for network in networks: try: - net = ip_network(u'{0}'.format(network), strict=False) + net = ip_network("{0}".format(network), strict=False) if addr in net: return True except ValueError: @@ -85,15 +99,15 @@ def is_within_local_networks(self, address): @staticmethod def parse_ip_network(address, strict=True, returns_ip=True): - """ return cidr parts of address """ + """return cidr parts of address""" try: - addr = ip_network(u'{0}'.format(address), strict=strict) + addr = ip_network("{0}".format(address), strict=strict) if strict or not returns_ip: return (str(addr.network_address), addr.prefixlen) else: # we parse the address with ipaddr just for type checking # but we use a regex to return the result as it dont kept the address bits - group = re.match(r'(.*)/(.*)', address) + group = re.match(r"(.*)/(.*)", address) if group: return (group.group(1), group.group(2)) except ValueError: @@ -102,46 +116,48 @@ def parse_ip_network(address, strict=True, returns_ip=True): def parse_address(self, param, allow_self=True): - """ validate param address field and returns it as a dict """ + """validate param address field and returns it as a dict""" if self.is_ipv6_address(param) or self.is_ipv6_network(param): addr = [param] else: - addr = param.split(':', maxsplit=3) + addr = param.split(":", maxsplit=3) if len(addr) > 3: - self.module.fail_json(msg='Cannot parse address %s' % (param)) + self.module.fail_json(msg="Cannot parse address %s" % (param)) address = addr[0] ret = dict() # Check if the first character is "!" - if address[0] == '!': + if address[0] == "!": # Invert the rule - ret['not'] = None + ret["not"] = None address = address[1:] - if address == 'NET' or address == 'IP': + if address == "NET" or address == "IP": interface = addr[1] if len(addr) > 1 else None ports = addr[2] if len(addr) > 2 else None - if interface is None or interface == '': - self.module.fail_json(msg='Cannot parse address %s' % (param)) + if interface is None or interface == "": + self.module.fail_json(msg="Cannot parse address %s" % (param)) - ret['network'] = self.parse_interface(interface) - if address == 'IP': - ret['network'] += 'ip' + ret["network"] = self.parse_interface(interface) + if address == "IP": + ret["network"] += "ip" else: ports = addr[1] if len(addr) > 1 else None - if address == 'any': - ret['any'] = None + if address == "any": + ret["any"] = None # rule with this firewall - elif allow_self and address == '(self)': - ret['network'] = '(self)' + elif allow_self and address == "(self)": + ret["network"] = "(self)" # rule with interface name (LAN, WAN...) elif self.is_interface_display_name(address): - ret['network'] = self.get_interface_by_display_name(address) + ret["network"] = self.get_interface_by_display_name(address) else: if not self.is_ip_or_alias(address): - self.module.fail_json(msg='Cannot parse address %s, not IP or alias' % (address)) - ret['address'] = address + self.module.fail_json( + msg="Cannot parse address %s, not IP or alias" % (address) + ) + ret["address"] = address if ports is not None: self.parse_port(ports, ret) @@ -152,16 +168,26 @@ def parse_address(self, param, allow_self=True): def parse_port(self, src_ports, ret): - """ validate and parse port address field and set it in ret """ - ports = src_ports.split('-') - if len(ports) > 2 or ports[0] is None or ports[0] == '' or len(ports) == 2 and (ports[1] is None or ports[1] == ''): - self.module.fail_json(msg='Cannot parse port %s' % (src_ports)) + """validate and parse port address field and set it in ret""" + ports = src_ports.split("-") + if ( + len(ports) > 2 + or ports[0] is None + or ports[0] == "" + or len(ports) == 2 + and (ports[1] is None or ports[1] == "") + ): + self.module.fail_json(msg="Cannot parse port %s" % (src_ports)) if not self.is_port_or_alias(ports[0]): - self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[0])) - ret['port'] = ports[0] + self.module.fail_json( + msg="Cannot parse port %s, not port number or alias" % (ports[0]) + ) + ret["port"] = ports[0] if len(ports) > 1: if not self.is_port_or_alias(ports[1]): - self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[1])) - ret['port'] += '-' + ports[1] + self.module.fail_json( + msg="Cannot parse port %s, not port number or alias" % (ports[1]) + ) + ret["port"] += "-" + ports[1] diff --git a/plugins/module_utils/__impl/checks.py b/plugins/module_utils/__impl/checks.py index 5088f1fe..05734abc 100644 --- a/plugins/module_utils/__impl/checks.py +++ b/plugins/module_utils/__impl/checks.py @@ -5,31 +5,39 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import re import socket def check_name(self, name, objtype): - """ check name validity """ + """check name validity""" msg = None - if len(name) >= 32 or len(re.findall(r'(^_*$|^\d*$|[^a-zA-Z0-9_])', name)) > 0: + if len(name) >= 32 or len(re.findall(r"(^_*$|^\d*$|[^a-zA-Z0-9_])", name)) > 0: msg = "The {0} name '{1}' must be less than 32 characters long, may not consist of only numbers, may not consist of only underscores, ".format( - objtype, name) + objtype, name + ) msg += "and may only contain the following characters: a-z, A-Z, 0-9, _" elif name in ["port", "pass"]: - msg = "The {0} name must not be either of the reserved words 'port' or 'pass'".format(objtype) + msg = "The {0} name must not be either of the reserved words 'port' or 'pass'".format( + objtype + ) else: try: socket.getprotobyname(name) - msg = 'The {0} name must not be an IP protocol name such as TCP, UDP, ICMP etc.'.format(objtype) + msg = "The {0} name must not be an IP protocol name such as TCP, UDP, ICMP etc.".format( + objtype + ) except socket.error: pass try: socket.getservbyname(name) - msg = 'The {0} name must not be a well-known or registered TCP or UDP port name such as ssh, smtp, pop3, tftp, http, openvpn etc.'.format(objtype) + msg = "The {0} name must not be a well-known or registered TCP or UDP port name such as ssh, smtp, pop3, tftp, http, openvpn etc.".format( + objtype + ) except socket.error: pass @@ -37,8 +45,10 @@ def check_name(self, name, objtype): self.module.fail_json(msg=msg) -def check_ip_address(self, address, ipprotocol, objtype, allow_networks=False, fail_ifnotip=False): - """ check address according to ipprotocol """ +def check_ip_address( + self, address, ipprotocol, objtype, allow_networks=False, fail_ifnotip=False +): + """check address according to ipprotocol""" if address is None: return if allow_networks: @@ -48,30 +58,34 @@ def check_ip_address(self, address, ipprotocol, objtype, allow_networks=False, f ipv4 = self.is_ipv4_address(address) ipv6 = self.is_ipv6_address(address) - if ipprotocol == 'inet': + if ipprotocol == "inet": if ipv6 or not ipv4 and fail_ifnotip: - self.module.fail_json(msg='{0} must use an IPv4 address'.format(objtype)) - elif ipprotocol == 'inet6': + self.module.fail_json(msg="{0} must use an IPv4 address".format(objtype)) + elif ipprotocol == "inet6": if ipv4 or not ipv6 and fail_ifnotip: - self.module.fail_json(msg='{0} must use an IPv6 address'.format(objtype)) - elif ipprotocol == 'inet46': + self.module.fail_json(msg="{0} must use an IPv6 address".format(objtype)) + elif ipprotocol == "inet46": if ipv4 or ipv6: - self.module.fail_json(msg='IPv4 and IPv6 addresses can not be used in objects that apply to both IPv4 and IPv6 (except within an alias).') + self.module.fail_json( + msg="IPv4 and IPv6 addresses can not be used in objects that apply to both IPv4 and IPv6 (except within an alias)." + ) def validate_openvpn_tunnel_network(self, network, ipproto): - """ check openvpn tunnel network validity - based on pfSense's openvpn_validate_tunnel_network() """ - if network is not None and network != '': - alias_elt = self.find_alias(network, aliastype='network') + """check openvpn tunnel network validity - based on pfSense's openvpn_validate_tunnel_network()""" + if network is not None and network != "": + alias_elt = self.find_alias(network, aliastype="network") if alias_elt is not None: - networks = alias_elt.find('address').text.split() + networks = alias_elt.find("address").text.split() if len(networks) > 1: - self.module.fail_json("The alias {0} contains more than one network".format(network)) + self.module.fail_json( + "The alias {0} contains more than one network".format(network) + ) network = networks[0] - if not self.is_ipv4_network(network, strict=False) and ipproto == 'ipv4': + if not self.is_ipv4_network(network, strict=False) and ipproto == "ipv4": self.module.fail_json("{0} is not a valid IPv4 network".format(network)) - if not self.is_ipv6_network(network, strict=False) and ipproto == 'ipv6': + if not self.is_ipv6_network(network, strict=False) and ipproto == "ipv6": self.module.fail_json("{0} is not a valid IPv6 network".format(network)) return True @@ -79,7 +93,9 @@ def validate_openvpn_tunnel_network(self, network, ipproto): def validate_string(self, name, objtype): - """ check string validity - similar to pfSense's do_input_validate() """ + """check string validity - similar to pfSense's do_input_validate()""" - if len(re.findall(r'[\000-\010\013\014\016-\037]', name)) > 0: - self.module.fail_json("The {0} name contains invalid characters.".format(objtype)) + if len(re.findall(r"[\000-\010\013\014\016-\037]", name)) > 0: + self.module.fail_json( + "The {0} name contains invalid characters.".format(objtype) + ) diff --git a/plugins/module_utils/__impl/interfaces.py b/plugins/module_utils/__impl/interfaces.py index d3687376..3f61af18 100644 --- a/plugins/module_utils/__impl/interfaces.py +++ b/plugins/module_utils/__impl/interfaces.py @@ -4,39 +4,40 @@ # Copyright: (c) 2019, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type def get_interface_by_display_name(self, name): - """ return interface_id by name """ + """return interface_id by name""" for interface in self.interfaces: - descr_elt = interface.find('descr') + descr_elt = interface.find("descr") if descr_elt is not None and descr_elt.text.strip().lower() == name.lower(): return interface.tag return None def get_interface_by_port(self, name): - """ return interface_id by port (os name) """ + """return interface_id by port (os name)""" for interface in self.interfaces: - if interface.find('if').text.strip() == name: + if interface.find("if").text.strip() == name: return interface.tag return None def get_interface_display_name(self, interface_id, return_none=False): - """ return interface display name if found, otherwhise return the interface_id """ - if interface_id == 'enc0': - return 'IPsec' - if interface_id == 'openvpn': + """return interface display name if found, otherwhise return the interface_id""" + if interface_id == "enc0": + return "IPsec" + if interface_id == "openvpn": if return_none and not self.is_openvpn_enabled(): return None - return 'OpenVPN' + return "OpenVPN" for interface in self.interfaces: if interface.tag == interface_id: - descr_elt = interface.find('descr') + descr_elt = interface.find("descr") if descr_elt is not None: return descr_elt.text.strip() break @@ -47,7 +48,7 @@ def get_interface_display_name(self, interface_id, return_none=False): def get_interface_elt(self, interface_id): - """ return interface """ + """return interface""" for interface in self.interfaces: if interface.tag == interface_id: return interface @@ -55,45 +56,55 @@ def get_interface_elt(self, interface_id): def get_interface_port(self, interface_id): - """ return interface port """ + """return interface port""" for interface in self.interfaces: if interface.tag == interface_id: - return interface.find('if').text.strip() + return interface.find("if").text.strip() return None def get_interface_port_by_display_name(self, name): - """ return interface port """ + """return interface port""" for interface in self.interfaces: - descr_elt = interface.find('descr') + descr_elt = interface.find("descr") if descr_elt is not None and descr_elt.text.strip().lower() == name.lower(): - return interface.find('if').text.strip() + return interface.find("if").text.strip() return None def get_interfaces_networks(self): - """ return interface local networks """ + """return interface local networks""" ret = [] for interface in self.interfaces: - if interface.find('enable') is None: + if interface.find("enable") is None: continue - ipaddr_elt = interface.find('ipaddr') - subnet_elt = interface.find('subnet') - if ipaddr_elt is not None and subnet_elt is not None and ipaddr_elt.text is not None and subnet_elt.text is not None: - ret.append('{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text)) - - ipaddr_elt = interface.find('ipaddrv6') - subnet_elt = interface.find('subnetv6') - if ipaddr_elt is not None and subnet_elt is not None and ipaddr_elt.text is not None and subnet_elt.text is not None: - ret.append('{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text)) + ipaddr_elt = interface.find("ipaddr") + subnet_elt = interface.find("subnet") + if ( + ipaddr_elt is not None + and subnet_elt is not None + and ipaddr_elt.text is not None + and subnet_elt.text is not None + ): + ret.append("{0}/{1}".format(ipaddr_elt.text, subnet_elt.text)) + + ipaddr_elt = interface.find("ipaddrv6") + subnet_elt = interface.find("subnetv6") + if ( + ipaddr_elt is not None + and subnet_elt is not None + and ipaddr_elt.text is not None + and subnet_elt.text is not None + ): + ret.append("{0}/{1}".format(ipaddr_elt.text, subnet_elt.text)) # TODO: add vip networks return ret def is_interface_port(self, interface_port): - """ determines if arg is a pfsense interface port or not """ + """determines if arg is a pfsense interface port or not""" for interface in self.interfaces: interface_elt = interface.tag.strip() if interface_elt == interface_port: @@ -102,9 +113,9 @@ def is_interface_port(self, interface_port): def is_interface_display_name(self, name): - """ determines if arg is an interface name or not """ + """determines if arg is an interface name or not""" for interface in self.interfaces: - descr_elt = interface.find('descr') + descr_elt = interface.find("descr") if descr_elt is not None: if descr_elt.text.strip().lower() == name.lower(): return True @@ -112,10 +123,10 @@ def is_interface_display_name(self, name): def is_interface_group(self, name): - """ determines if arg is an interface group name or not """ + """determines if arg is an interface group name or not""" if self.ifgroups is not None: for interface in self.ifgroups: - ifname_elt = interface.find('ifname') + ifname_elt = interface.find("ifname") if ifname_elt is not None: # ifgroup names appear to be case sensitive if ifname_elt.text.strip() == name: @@ -124,11 +135,19 @@ def is_interface_group(self, name): def parse_interface(self, interface, fail=True, with_virtual=True, with_gwgroup=False): - """ validate param interface field """ - if with_virtual and (interface == 'enc0' or interface.lower() == 'ipsec') and self.is_ipsec_enabled(): - return 'enc0' - if with_virtual and (interface == 'openvpn' or interface.lower() == 'openvpn') and self.is_openvpn_enabled(): - return 'openvpn' + """validate param interface field""" + if ( + with_virtual + and (interface == "enc0" or interface.lower() == "ipsec") + and self.is_ipsec_enabled() + ): + return "enc0" + if ( + with_virtual + and (interface == "openvpn" or interface.lower() == "openvpn") + and self.is_openvpn_enabled() + ): + return "openvpn" if with_gwgroup and self.is_gateway_group(interface): return interface @@ -140,5 +159,5 @@ def parse_interface(self, interface, fail=True, with_virtual=True, with_gwgroup= return interface if fail: - self.module.fail_json(msg='%s is not a valid interface' % (interface)) + self.module.fail_json(msg="%s is not a valid interface" % (interface)) return None diff --git a/plugins/module_utils/alias.py b/plugins/module_utils/alias.py index b6554f7b..df1c0dd3 100644 --- a/plugins/module_utils/alias.py +++ b/plugins/module_utils/alias.py @@ -5,22 +5,28 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) ALIAS_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - state=dict(default='present', choices=['present', 'absent']), - type=dict(required=False, choices=['host', 'network', 'port', 'urltable', 'urltable_ports']), - address=dict(default=None, required=False, type='str'), - url=dict(default=None, required=False, type='str'), - descr=dict(default=None, required=False, type='str'), - detail=dict(default=None, required=False, type='str'), - updatefreq=dict(default=None, required=False, type='int'), + name=dict(required=True, type="str"), + state=dict(default="present", choices=["present", "absent"]), + type=dict( + required=False, + choices=["host", "network", "port", "urltable", "urltable_ports"], + ), + address=dict(default=None, required=False, type="str"), + url=dict(default=None, required=False, type="str"), + descr=dict(default=None, required=False, type="str"), + detail=dict(default=None, required=False, type="str"), + updatefreq=dict(default=None, required=False, type="int"), ) ALIAS_MUTUALLY_EXCLUSIVE = [ - ('address', 'url'), + ("address", "url"), ] ALIAS_REQUIRED_IF = [ @@ -41,8 +47,8 @@ ] ALIAS_CREATE_DEFAULT = dict( - descr='', - detail='', + descr="", + detail="", ) ALIAS_PHP_COMMAND_SET = """ @@ -52,7 +58,7 @@ class PFSenseAliasModule(PFSenseModuleBase): - """ module managing pfsense aliases """ + """module managing pfsense aliases""" ############################## # unit tests @@ -60,15 +66,23 @@ class PFSenseAliasModule(PFSenseModuleBase): # Must be class method for unit test usage @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return ALIAS_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseAliasModule, self).__init__(module, pfsense, root='aliases', node='alias', key='name', update_php=ALIAS_PHP_COMMAND_SET, - map_param_if=ALIAS_MAP_PARAM_IF, create_default=ALIAS_CREATE_DEFAULT) + super(PFSenseAliasModule, self).__init__( + module, + pfsense, + root="aliases", + node="alias", + key="name", + update_php=ALIAS_PHP_COMMAND_SET, + map_param_if=ALIAS_MAP_PARAM_IF, + create_default=ALIAS_CREATE_DEFAULT, + ) # Override for use with aggregate self.argument_spec = ALIAS_ARGUMENT_SPEC @@ -76,40 +90,66 @@ def __init__(self, module, pfsense=None): # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - self.pfsense.check_name(params['name'], 'alias') + self.pfsense.check_name(params["name"], "alias") - if params['state'] == 'present': + if params["state"] == "present": # the GUI does not allow to create 2 aliases with same name and differents types - alias_elt = self.pfsense.find_alias(params['name']) + alias_elt = self.pfsense.find_alias(params["name"]) if alias_elt is not None: - if params['type'] not in ['host', 'network'] or alias_elt.find('type').text not in ['host', 'network']: - if params['type'] != alias_elt.find('type').text: - self.module.fail_json(msg='An alias with this name and a different type already exists: \'{0}\''.format(params['name'])) + if params["type"] not in ["host", "network"] or alias_elt.find( + "type" + ).text not in ["host", "network"]: + if params["type"] != alias_elt.find("type").text: + self.module.fail_json( + msg="An alias with this name and a different type already exists: '{0}'".format( + params["name"] + ) + ) # Aliases cannot have the same name as an interface description - if self.pfsense.get_interface_by_display_name(params['name']) is not None: - self.module.fail_json(msg='An interface description with this name already exists: \'{0}\''.format(params['name'])) + if self.pfsense.get_interface_by_display_name(params["name"]) is not None: + self.module.fail_json( + msg="An interface description with this name already exists: '{0}'".format( + params["name"] + ) + ) # updatefreq is for urltable only - if params['updatefreq'] is not None and params['type'] != 'urltable' and params['type'] != 'urltable_ports': - self.module.fail_json(msg='updatefreq is only valid with type urltable or urltable_ports') - - details = params['detail'].split('||') if params['detail'] is not None else [] - if params['address'] is not None: + if ( + params["updatefreq"] is not None + and params["type"] != "urltable" + and params["type"] != "urltable_ports" + ): + self.module.fail_json( + msg="updatefreq is only valid with type urltable or urltable_ports" + ) + + details = ( + params["detail"].split("||") if params["detail"] is not None else [] + ) + if params["address"] is not None: # check details count - addresses = params['address'].split(' ') + addresses = params["address"].split(" ") if len(details) > len(addresses): - self.module.fail_json(msg='Too many details in relation to addresses') + self.module.fail_json( + msg="Too many details in relation to addresses" + ) # warn if address is used with urltable to urltable_ports - if params['type'] in ['urltable', 'urltable_ports']: - self.module.warn('Use of "address" with {type} is depracated, please use "url" instead'.format(type=params['type'])) + if params["type"] in ["urltable", "urltable_ports"]: + self.module.warn( + 'Use of "address" with {type} is depracated, please use "url" instead'.format( + type=params["type"] + ) + ) # pfSense GUI rule for detail in details: - if detail.startswith('|') or detail.endswith('|'): - self.module.fail_json(msg='Vertical bars (|) at start or end of descriptions not allowed') + if detail.startswith("|") or detail.endswith("|"): + self.module.fail_json( + msg="Vertical bars (|) at start or end of descriptions not allowed" + ) diff --git a/plugins/module_utils/arg_route.py b/plugins/module_utils/arg_route.py index 063b65f8..25034e62 100644 --- a/plugins/module_utils/arg_route.py +++ b/plugins/module_utils/arg_route.py @@ -3,7 +3,8 @@ # Copyright: (c) 2024, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type @@ -14,12 +15,15 @@ def p2o_cert(self, name, params, obj): obj[name] = self.pfsense.get_certref(params[name]) + def p2o_interface(self, name, params, obj): obj[name] = self.pfsense.parse_interface(params[name], with_virtual=True) def p2o_interface_with_gwgroup(self, name, params, obj): - obj[name] = self.pfsense.parse_interface(params[name], with_virtual=False, with_gwgroup=True) + obj[name] = self.pfsense.parse_interface( + params[name], with_virtual=False, with_gwgroup=True + ) def p2o_interface_without_virtual(self, name, params, obj): diff --git a/plugins/module_utils/default_gateway.py b/plugins/module_utils/default_gateway.py index 4fa4d3cf..3f6c3b8c 100644 --- a/plugins/module_utils/default_gateway.py +++ b/plugins/module_utils/default_gateway.py @@ -6,30 +6,35 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) DEFAULT_GATEWAY_ARGUMENT_SPEC = dict( - gateway=dict(type='str'), - ipprotocol=dict(default='inet', choices=['inet', 'inet6']), + gateway=dict(type="str"), + ipprotocol=dict(default="inet", choices=["inet", "inet6"]), ) class PFSenseDefaultGatewayModule(PFSenseModuleBase): - """ module managing pfsense default gateways """ + """module managing pfsense default gateways""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return DEFAULT_GATEWAY_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseDefaultGatewayModule, self).__init__(module, pfsense, root='gateways') + super(PFSenseDefaultGatewayModule, self).__init__( + module, pfsense, root="gateways" + ) self.name = "pfsense_default_gateway" self.target_elt = self.root_elt self.interface_elt = None @@ -39,7 +44,7 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params + """return a dict from module params gateway required, str ipprotocol default : inet, choice inet/inet6 """ @@ -49,25 +54,29 @@ def _params_to_obj(self): # Modification if params["gateway"]: - my_defaultgw = self._gw2machine(params['gateway']) - if params['ipprotocol'] == "inet": - obj['defaultgw4'] = my_defaultgw + my_defaultgw = self._gw2machine(params["gateway"]) + if params["ipprotocol"] == "inet": + obj["defaultgw4"] = my_defaultgw self.result["defaultgw4"] = params["gateway"] - elif params['ipprotocol'] == "inet6": - obj['defaultgw6'] = my_defaultgw + elif params["ipprotocol"] == "inet6": + obj["defaultgw6"] = my_defaultgw self.result["defaultgw6"] = params["gateway"] else: - self.module.fail_json(msg='Please specify a valid ipprotocol (inet/inet6)') + self.module.fail_json( + msg="Please specify a valid ipprotocol (inet/inet6)" + ) return obj def _validate_params(self): - """ do some extra checks on input parameters + """do some extra checks on input parameters gateway required, str ipprotocol default : inet, choice inet/inet6 """ params = self.params - gateway_list = ["none", "automatic"] + [gw["Name"] for gw in self.pfsense.find_active_gateways()] + gateway_list = ["none", "automatic"] + [ + gw["Name"] for gw in self.pfsense.find_active_gateways() + ] # get list of current default gateways and append gateway_groups to list for elt in self.root_elt: @@ -80,17 +89,19 @@ def _validate_params(self): if params["gateway"]: if str(params["gateway"]) not in gateway_list: - self.module.fail_json(msg="Unknown gateway %s : %s" % (params["gateway"], gateway_list)) + self.module.fail_json( + msg="Unknown gateway %s : %s" % (params["gateway"], gateway_list) + ) ############################## # XML processing # def _create_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" if self.params["ipprotocol"] == "inet": - return self.pfsense.new_element('defaultgw4') + return self.pfsense.new_element("defaultgw4") elif self.params["ipprotocol"] == "inet6": - return self.pfsense.new_element('defaultgw6') + return self.pfsense.new_element("defaultgw6") ############################## # Utilities @@ -126,12 +137,12 @@ def _gw2human(gateway): @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ + """returns the list of params to remove if they are not set""" return [] ############################## def run(self, params): - """ process input params to add/update/delete """ + """process input params to add/update/delete""" self.params = params self._check_deprecated_params() self._check_onward_params() @@ -143,8 +154,9 @@ def run(self, params): self._add() def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell(''' + """make the target pfsense reload""" + return self.pfsense.phpshell( + """ require_once("filter.inc"); $retval = 0; @@ -157,21 +169,26 @@ def _update(self): send_event("service reload dyndnsall"); if ($retval == 0) clear_subsystem_dirty('staticroutes'); -''') +""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ + """return obj's name""" return "" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if self.params["ipprotocol"] == "inet": - values += self.format_updated_cli_field(self.obj, before, 'defaultgw4', add_comma=values) + values += self.format_updated_cli_field( + self.obj, before, "defaultgw4", add_comma=values + ) elif self.params["ipprotocol"] == "inet6": - values += self.format_updated_cli_field(self.obj, before, 'defaultgw6', add_comma=values) + values += self.format_updated_cli_field( + self.obj, before, "defaultgw6", add_comma=values + ) return values diff --git a/plugins/module_utils/dhcp_server.py b/plugins/module_utils/dhcp_server.py index c242b9a9..9b39d532 100644 --- a/plugins/module_utils/dhcp_server.py +++ b/plugins/module_utils/dhcp_server.py @@ -4,58 +4,74 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type from ipaddress import ip_address, ip_network import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) DHCPSERVER_ARGUMENT_SPEC = dict( - state=dict(type='str', default='present', choices=['present', 'absent']), - interface=dict(required=True, type='str'), - enable=dict(type='bool', default=True), - range_from=dict(type='str'), - range_to=dict(type='str'), - failover_peerip=dict(type='str'), - defaultleasetime=dict(type='int'), - maxleasetime=dict(type='int'), - netmask=dict(type='str'), - gateway=dict(type='str'), - domain=dict(type='str'), - domainsearchlist=dict(type='str'), - ddnsdomain=dict(type='str'), - ddnsdomainprimary=dict(type='str'), - ddnsdomainkeyname=dict(type='str', no_log=False), - ddnsdomainkeyalgorithm=dict(type='str', default='hmac-md5', choices=['hmac-md5', 'hmac-sha1', 'hmac-sha224', 'hmac-sha256', 'hmac-sha384', 'hmac-sha512']), - ddnsdomainkey=dict(type='str', no_log=True), - mac_allow=dict(type='list', elements='str'), - mac_deny=dict(type='list', elements='str'), - ddnsclientupdates=dict(type='str', default='allow', choices=['allow', 'deny', 'ignore']), - tftp=dict(type='str'), - ldap=dict(type='str'), - nextserver=dict(type='str'), - filename=dict(type='str'), - filename32=dict(type='str'), - filename64=dict(type='str'), - rootpath=dict(type='str'), - numberoptions=dict(type='str'), - winsserver=dict(type='list', elements='str'), - dnsserver=dict(type='list', elements='str'), - ntpserver=dict(type='list', elements='str'), - ignorebootp=dict(type='bool'), - denyunknown=dict(type='str', choices=['disabled', 'enabled', 'class']), - nonak=dict(type='bool'), - ignoreclientuids=dict(type='bool'), - staticarp=dict(type='bool'), - dhcpinlocaltime=dict(type='bool'), - statsgraph=dict(type='bool'), - disablepingcheck=dict(type='bool'), + state=dict(type="str", default="present", choices=["present", "absent"]), + interface=dict(required=True, type="str"), + enable=dict(type="bool", default=True), + range_from=dict(type="str"), + range_to=dict(type="str"), + failover_peerip=dict(type="str"), + defaultleasetime=dict(type="int"), + maxleasetime=dict(type="int"), + netmask=dict(type="str"), + gateway=dict(type="str"), + domain=dict(type="str"), + domainsearchlist=dict(type="str"), + ddnsdomain=dict(type="str"), + ddnsdomainprimary=dict(type="str"), + ddnsdomainkeyname=dict(type="str", no_log=False), + ddnsdomainkeyalgorithm=dict( + type="str", + default="hmac-md5", + choices=[ + "hmac-md5", + "hmac-sha1", + "hmac-sha224", + "hmac-sha256", + "hmac-sha384", + "hmac-sha512", + ], + ), + ddnsdomainkey=dict(type="str", no_log=True), + mac_allow=dict(type="list", elements="str"), + mac_deny=dict(type="list", elements="str"), + ddnsclientupdates=dict( + type="str", default="allow", choices=["allow", "deny", "ignore"] + ), + tftp=dict(type="str"), + ldap=dict(type="str"), + nextserver=dict(type="str"), + filename=dict(type="str"), + filename32=dict(type="str"), + filename64=dict(type="str"), + rootpath=dict(type="str"), + numberoptions=dict(type="str"), + winsserver=dict(type="list", elements="str"), + dnsserver=dict(type="list", elements="str"), + ntpserver=dict(type="list", elements="str"), + ignorebootp=dict(type="bool"), + denyunknown=dict(type="str", choices=["disabled", "enabled", "class"]), + nonak=dict(type="bool"), + ignoreclientuids=dict(type="bool"), + staticarp=dict(type="bool"), + dhcpinlocaltime=dict(type="bool"), + statsgraph=dict(type="bool"), + disablepingcheck=dict(type="bool"), ) class PFSenseDHCPServerModule(PFSenseModuleBase): - """ module managing pfsense DHCP server settings """ + """module managing pfsense DHCP server settings""" @staticmethod def get_argument_spec(): @@ -70,7 +86,7 @@ def __init__(self, module, pfsense=None): self.name = "pfsense_dhcp_server" self.obj = dict() - self.root_elt = self.pfsense.get_element('dhcpd', create_node=True) + self.root_elt = self.pfsense.get_element("dhcpd", create_node=True) self.target = None self.network = None @@ -85,13 +101,16 @@ def _get_logical_interface(self, interface): return iface.tag # Check if it matches the physical interface name (e.g., 'em0', 'igb0') - if_elt = iface.find('if') + if_elt = iface.find("if") if if_elt is not None and if_elt.text.strip().lower() == interface.lower(): return iface.tag # Check if it matches the interface description - descr_elt = iface.find('descr') - if descr_elt is not None and descr_elt.text.strip().lower() == interface.lower(): + descr_elt = iface.find("descr") + if ( + descr_elt is not None + and descr_elt.text.strip().lower() == interface.lower() + ): return iface.tag return None @@ -99,18 +118,22 @@ def _get_logical_interface(self, interface): def _is_valid_netif(self, netif): for nic in self.pfsense.interfaces: if nic.tag == netif: - if nic.find('ipaddr') is not None: - ipaddr = nic.find('ipaddr').text + if nic.find("ipaddr") is not None: + ipaddr = nic.find("ipaddr").text if ipaddr is not None: - if nic.find('subnet') is not None: - subnet = int(nic.find('subnet').text) + if nic.find("subnet") is not None: + subnet = int(nic.find("subnet").text) if subnet < 31: - self.network = ip_network(u'{0}/{1}'.format(ipaddr, subnet), strict=False) + self.network = ip_network( + "{0}/{1}".format(ipaddr, subnet), strict=False + ) return True return False def _is_valid_macaddr(self, macaddr): - return bool(re.fullmatch(r'(?:[0-9a-fA-F]{2}[:-]){5}[0-9a-fA-F]{2}', macaddr, re.I)) + return bool( + re.fullmatch(r"(?:[0-9a-fA-F]{2}[:-]){5}[0-9a-fA-F]{2}", macaddr, re.I) + ) def _params_to_obj(self): """return a dict from module params""" @@ -119,44 +142,71 @@ def _params_to_obj(self): obj = dict() self.obj = obj - if params['state'] == 'present': - - self._get_ansible_param(obj, 'range', force_value={}, force=True) - self._get_ansible_param(obj['range'], 'range_from', fname='from', force=True) - self._get_ansible_param(obj['range'], 'range_to', fname='to', force=True) + if params["state"] == "present": + self._get_ansible_param(obj, "range", force_value={}, force=True) + self._get_ansible_param( + obj["range"], "range_from", fname="from", force=True + ) + self._get_ansible_param(obj["range"], "range_to", fname="to", force=True) # Forced options - for option in ['failover_peerip', 'defaultleasetime', 'maxleasetime', - 'netmask', 'gateway', 'domain', 'domainsearchlist', - 'ddnsdomain', 'ddnsdomainprimary', 'ddnsdomainkeyname', - 'ddnsdomainkeyalgorithm', 'ddnsdomainkey', 'mac_allow', - 'mac_deny', 'ddnsclientupdates', 'tftp', 'ldap', - 'nextserver', 'filename', 'filename32', 'filename64', - 'rootpath', 'numberoptions']: + for option in [ + "failover_peerip", + "defaultleasetime", + "maxleasetime", + "netmask", + "gateway", + "domain", + "domainsearchlist", + "ddnsdomain", + "ddnsdomainprimary", + "ddnsdomainkeyname", + "ddnsdomainkeyalgorithm", + "ddnsdomainkey", + "mac_allow", + "mac_deny", + "ddnsclientupdates", + "tftp", + "ldap", + "nextserver", + "filename", + "filename32", + "filename64", + "rootpath", + "numberoptions", + ]: self._get_ansible_param(obj, option, force=True) - for option in ['mac_allow', 'mac_deny']: + for option in ["mac_allow", "mac_deny"]: if params[option] is None: params[option] = "" - self._get_ansible_param(obj, ','.join(params[option])) + self._get_ansible_param(obj, ",".join(params[option])) # Non-forced options - for option in ['winsserver', 'dnsserver', 'ntpserver']: + for option in ["winsserver", "dnsserver", "ntpserver"]: self._get_ansible_param(obj, option) - for option in ['enable', 'ignorebootp', 'nonak', 'ignoreclientuids', - 'staticarp', 'disablepingcheck']: - self._get_ansible_param_bool(obj, option, value='') + for option in [ + "enable", + "ignorebootp", + "nonak", + "ignoreclientuids", + "staticarp", + "disablepingcheck", + ]: + self._get_ansible_param_bool(obj, option, value="") - for option in ['dhcpinlocaltime', 'statsgraph']: - self._get_ansible_param_bool(obj, option, value='yes') + for option in ["dhcpinlocaltime", "statsgraph"]: + self._get_ansible_param_bool(obj, option, value="yes") - self._get_ansible_param(obj, 'denyunknown') - if obj.get('denyunknown') == 'disabled': - del obj['denyunknown'] + self._get_ansible_param(obj, "denyunknown") + if obj.get("denyunknown") == "disabled": + del obj["denyunknown"] # Defaulted options - self._get_ansible_param(obj, 'ddnsdomainkeyalgorithm', force_value='hmac-md5', force=True) + self._get_ansible_param( + obj, "ddnsdomainkeyalgorithm", force_value="hmac-md5", force=True + ) return obj @@ -164,55 +214,89 @@ def _validate_params(self): """do some extra checks on input parameters""" params = self.params - self.target = self._get_logical_interface(params['interface']) + self.target = self._get_logical_interface(params["interface"]) if self.target is None or self.target.lower() == "wan": - self.module.fail_json(msg=f"The specified interface {params['interface']} is not a valid logical interface or cannot be mapped to one") + self.module.fail_json( + msg=f"The specified interface {params['interface']} is not a valid logical interface or cannot be mapped to one" + ) if not self._is_valid_netif(self.target): - self.module.fail_json(msg=f"The specified interface {params['interface']} is not a valid logical interface") - - if params['state'] == 'present' and params['enable']: - if params.get('range_from') is None or params.get('range_to') is None: - self.module.fail_json(msg=f"The specified interface {params['interface']}'requires an IP range") - - if not self.pfsense.is_ipv4_address(params['range_from']): - self.module.fail_json(msg="The 'range_from' address is not a valid IPv4 address") - if not self.pfsense.is_ipv4_address(params['range_to']): - self.module.fail_json(msg="The 'range_to' address is not a valid IPv4 address") - - if not ip_address(params['range_from']) in self.network or not ip_address(params['range_to']) in self.network: - self.module.fail_json(msg=f"The IP address must lie in the {params['interface']} subnet") - - if ip_address(params['range_from']) >= ip_address(params['range_to']): - self.module.fail_json(msg=f"The interface {params['interface']} must have a valid IP range pool") - - if params.get('gateway'): - if not self.pfsense.is_ipv4_address(params['gateway']): - self.module.fail_json(msg="The 'gateway' is not a valid IPv4 address") - - if params.get('mac_allow'): + self.module.fail_json( + msg=f"The specified interface {params['interface']} is not a valid logical interface" + ) + + if params["state"] == "present" and params["enable"]: + if params.get("range_from") is None or params.get("range_to") is None: + self.module.fail_json( + msg=f"The specified interface {params['interface']}'requires an IP range" + ) + + if not self.pfsense.is_ipv4_address(params["range_from"]): + self.module.fail_json( + msg="The 'range_from' address is not a valid IPv4 address" + ) + if not self.pfsense.is_ipv4_address(params["range_to"]): + self.module.fail_json( + msg="The 'range_to' address is not a valid IPv4 address" + ) + + if ( + ip_address(params["range_from"]) not in self.network + or ip_address(params["range_to"]) not in self.network + ): + self.module.fail_json( + msg=f"The IP address must lie in the {params['interface']} subnet" + ) + + if ip_address(params["range_from"]) >= ip_address(params["range_to"]): + self.module.fail_json( + msg=f"The interface {params['interface']} must have a valid IP range pool" + ) + + if params.get("gateway"): + if not self.pfsense.is_ipv4_address(params["gateway"]): + self.module.fail_json( + msg="The 'gateway' is not a valid IPv4 address" + ) + + if params.get("mac_allow"): for macaddr in params["mac_allow"]: is_valid = self._is_valid_macaddr(macaddr) if not is_valid: - self.module.fail_json(msg=f"The MAC address {macaddr} is invalid") + self.module.fail_json( + msg=f"The MAC address {macaddr} is invalid" + ) - if params.get('mac_deny'): + if params.get("mac_deny"): for macaddr in params["mac_deny"]: is_valid = self._is_valid_macaddr(macaddr) if not is_valid: - self.module.fail_json(msg=f"The MAC address {macaddr} is invalid") + self.module.fail_json( + msg=f"The MAC address {macaddr} is invalid" + ) - if params.get('denyunknown') not in [None, 'disabled', 'enabled', 'class']: - self.module.fail_json(msg=f"The option {params['denyunknown']} is invalid, use 'disabled', 'enabled' or 'class'") + if params.get("denyunknown") not in [None, "disabled", "enabled", "class"]: + self.module.fail_json( + msg=f"The option {params['denyunknown']} is invalid, use 'disabled', 'enabled' or 'class'" + ) ############################## # XML processing # def _get_params_to_remove(self): """returns the list of params to remove if they are not set""" - params = ['enable', 'ignorebootp', 'nonak', 'ignoreclientuids', 'staticarp', 'disablepingcheck', 'dhcpinlocaltime', 'statsgraph'] - if self.params.get('denyunknown') == 'disabled': - params.append('denyunknown') + params = [ + "enable", + "ignorebootp", + "nonak", + "ignoreclientuids", + "staticarp", + "disablepingcheck", + "dhcpinlocaltime", + "statsgraph", + ] + if self.params.get("denyunknown") == "disabled": + params.append("denyunknown") return params def _create_target(self): @@ -232,63 +316,83 @@ def _get_obj_name(self): def _log_fields(self, before=None): """generate pseudo-CLI command fields parameters to create an obj""" - values = '' + values = "" if before is None: - values += self.format_cli_field(self.obj, 'enable', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.obj["range"], 'from', fname="range_from") - values += self.format_cli_field(self.obj["range"], 'to', fname="range_to") - values += self.format_cli_field(self.obj, 'failover_peerip') - values += self.format_cli_field(self.obj, 'defaultleasetime') - values += self.format_cli_field(self.obj, 'maxleasetime') - values += self.format_cli_field(self.obj, 'netmask') - values += self.format_cli_field(self.obj, 'gateway') - values += self.format_cli_field(self.obj, 'domain') - values += self.format_cli_field(self.obj, 'domainsearchlist') - values += self.format_cli_field(self.obj, 'ddnsdomain') - values += self.format_cli_field(self.obj, 'ddnsdomainprimary') - values += self.format_cli_field(self.obj, 'ddnsdomainkeyname') - values += self.format_cli_field(self.obj, 'ddnsdomainkeyalgorithm') - values += self.format_cli_field(self.obj, 'ddnsdomainkey') - values += self.format_cli_field(self.obj, 'mac_allow') - values += self.format_cli_field(self.obj, 'mac_deny') - values += self.format_cli_field(self.obj, 'ddnsclientupdates') - values += self.format_cli_field(self.obj, 'tftp') - values += self.format_cli_field(self.obj, 'ldap') - values += self.format_cli_field(self.obj, 'nextserver') - values += self.format_cli_field(self.obj, 'filename') - values += self.format_cli_field(self.obj, 'filename32') - values += self.format_cli_field(self.obj, 'filename64') - values += self.format_cli_field(self.obj, 'rootpath') - values += self.format_cli_field(self.obj, 'numberoptions') - values += self.format_cli_field(self.obj, 'denyunknown') + values += self.format_cli_field(self.obj, "enable", fvalue=self.fvalue_bool) + values += self.format_cli_field( + self.obj["range"], "from", fname="range_from" + ) + values += self.format_cli_field(self.obj["range"], "to", fname="range_to") + values += self.format_cli_field(self.obj, "failover_peerip") + values += self.format_cli_field(self.obj, "defaultleasetime") + values += self.format_cli_field(self.obj, "maxleasetime") + values += self.format_cli_field(self.obj, "netmask") + values += self.format_cli_field(self.obj, "gateway") + values += self.format_cli_field(self.obj, "domain") + values += self.format_cli_field(self.obj, "domainsearchlist") + values += self.format_cli_field(self.obj, "ddnsdomain") + values += self.format_cli_field(self.obj, "ddnsdomainprimary") + values += self.format_cli_field(self.obj, "ddnsdomainkeyname") + values += self.format_cli_field(self.obj, "ddnsdomainkeyalgorithm") + values += self.format_cli_field(self.obj, "ddnsdomainkey") + values += self.format_cli_field(self.obj, "mac_allow") + values += self.format_cli_field(self.obj, "mac_deny") + values += self.format_cli_field(self.obj, "ddnsclientupdates") + values += self.format_cli_field(self.obj, "tftp") + values += self.format_cli_field(self.obj, "ldap") + values += self.format_cli_field(self.obj, "nextserver") + values += self.format_cli_field(self.obj, "filename") + values += self.format_cli_field(self.obj, "filename32") + values += self.format_cli_field(self.obj, "filename64") + values += self.format_cli_field(self.obj, "rootpath") + values += self.format_cli_field(self.obj, "numberoptions") + values += self.format_cli_field(self.obj, "denyunknown") else: - values += self.format_updated_cli_field(self.obj, before, 'enable', fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj["range"], before["range"], 'from', fname="range_from") - values += self.format_updated_cli_field(self.obj["range"], before["range"], 'to', fname="range_to") - values += self.format_updated_cli_field(self.obj, before, 'failover_peerip') - values += self.format_updated_cli_field(self.obj, before, 'defaultleasetime') - values += self.format_updated_cli_field(self.obj, before, 'maxleasetime') - values += self.format_updated_cli_field(self.obj, before, 'netmask') - values += self.format_updated_cli_field(self.obj, before, 'gateway') - values += self.format_updated_cli_field(self.obj, before, 'domain') - values += self.format_updated_cli_field(self.obj, before, 'domainsearchlist') - values += self.format_updated_cli_field(self.obj, before, 'ddnsdomain') - values += self.format_updated_cli_field(self.obj, before, 'ddnsdomainprimary') - values += self.format_updated_cli_field(self.obj, before, 'ddnsdomainkeyname') - values += self.format_updated_cli_field(self.obj, before, 'ddnsdomainkeyalgorithm') - values += self.format_updated_cli_field(self.obj, before, 'ddnsdomainkey') - values += self.format_updated_cli_field(self.obj, before, 'mac_allow') - values += self.format_updated_cli_field(self.obj, before, 'mac_deny') - values += self.format_updated_cli_field(self.obj, before, 'ddnsclientupdates') - values += self.format_updated_cli_field(self.obj, before, 'tftp') - values += self.format_updated_cli_field(self.obj, before, 'ldap') - values += self.format_updated_cli_field(self.obj, before, 'nextserver') - values += self.format_updated_cli_field(self.obj, before, 'filename') - values += self.format_updated_cli_field(self.obj, before, 'filename32') - values += self.format_updated_cli_field(self.obj, before, 'filename64') - values += self.format_updated_cli_field(self.obj, before, 'rootpath') - values += self.format_updated_cli_field(self.obj, before, 'numberoptions') - values += self.format_updated_cli_field(self.obj, before, 'denyunknown') + values += self.format_updated_cli_field( + self.obj, before, "enable", fvalue=self.fvalue_bool + ) + values += self.format_updated_cli_field( + self.obj["range"], before["range"], "from", fname="range_from" + ) + values += self.format_updated_cli_field( + self.obj["range"], before["range"], "to", fname="range_to" + ) + values += self.format_updated_cli_field(self.obj, before, "failover_peerip") + values += self.format_updated_cli_field( + self.obj, before, "defaultleasetime" + ) + values += self.format_updated_cli_field(self.obj, before, "maxleasetime") + values += self.format_updated_cli_field(self.obj, before, "netmask") + values += self.format_updated_cli_field(self.obj, before, "gateway") + values += self.format_updated_cli_field(self.obj, before, "domain") + values += self.format_updated_cli_field( + self.obj, before, "domainsearchlist" + ) + values += self.format_updated_cli_field(self.obj, before, "ddnsdomain") + values += self.format_updated_cli_field( + self.obj, before, "ddnsdomainprimary" + ) + values += self.format_updated_cli_field( + self.obj, before, "ddnsdomainkeyname" + ) + values += self.format_updated_cli_field( + self.obj, before, "ddnsdomainkeyalgorithm" + ) + values += self.format_updated_cli_field(self.obj, before, "ddnsdomainkey") + values += self.format_updated_cli_field(self.obj, before, "mac_allow") + values += self.format_updated_cli_field(self.obj, before, "mac_deny") + values += self.format_updated_cli_field( + self.obj, before, "ddnsclientupdates" + ) + values += self.format_updated_cli_field(self.obj, before, "tftp") + values += self.format_updated_cli_field(self.obj, before, "ldap") + values += self.format_updated_cli_field(self.obj, before, "nextserver") + values += self.format_updated_cli_field(self.obj, before, "filename") + values += self.format_updated_cli_field(self.obj, before, "filename32") + values += self.format_updated_cli_field(self.obj, before, "filename64") + values += self.format_updated_cli_field(self.obj, before, "rootpath") + values += self.format_updated_cli_field(self.obj, before, "numberoptions") + values += self.format_updated_cli_field(self.obj, before, "denyunknown") return values ############################## @@ -296,15 +400,17 @@ def _log_fields(self, before=None): # def _update(self): """make the target pfsense reload""" - return self.pfsense.phpshell(""" + return self.pfsense.phpshell( + """ require_once("util.inc"); require_once("services.inc"); services_dhcpd_configure(); - """) + """ + ) def _pre_remove_target_elt(self): - self.diff['after'] = {} + self.diff["after"] = {} if self.target_elt is not None: - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) else: - self.diff['before'] = {} + self.diff["before"] = {} diff --git a/plugins/module_utils/gateway.py b/plugins/module_utils/gateway.py index 0761970c..9815c8f4 100644 --- a/plugins/module_utils/gateway.py +++ b/plugins/module_utils/gateway.py @@ -4,24 +4,27 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) from ipaddress import ip_address, ip_network GATEWAY_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - name=dict(required=True, type='str'), - interface=dict(required=False, type='str'), - ipprotocol=dict(default='inet', choices=['inet', 'inet6']), - gateway=dict(required=False, type='str'), - descr=dict(default='', type='str'), - disabled=dict(default=False, type='bool'), - monitor=dict(required=False, type='str'), - monitor_disable=dict(default=False, type='bool'), - action_disable=dict(default=False, type='bool'), - force_down=dict(default=False, type='bool'), - weight=dict(default=1, required=False, type='int'), - nonlocalgateway=dict(default=False, type='bool'), + state=dict(default="present", choices=["present", "absent"]), + name=dict(required=True, type="str"), + interface=dict(required=False, type="str"), + ipprotocol=dict(default="inet", choices=["inet", "inet6"]), + gateway=dict(required=False, type="str"), + descr=dict(default="", type="str"), + disabled=dict(default=False, type="bool"), + monitor=dict(required=False, type="str"), + monitor_disable=dict(default=False, type="bool"), + action_disable=dict(default=False, type="bool"), + force_down=dict(default=False, type="bool"), + weight=dict(default=1, required=False, type="int"), + nonlocalgateway=dict(default=False, type="bool"), ) GATEWAY_REQUIRED_IF = [ @@ -30,18 +33,25 @@ class PFSenseGatewayModule(PFSenseModuleBase): - """ module managing pfsense gateways """ + """module managing pfsense gateways""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return GATEWAY_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseGatewayModule, self).__init__(module, pfsense, root='gateways', create_root=True, node='gateway_item', key='name') + super(PFSenseGatewayModule, self).__init__( + module, + pfsense, + root="gateways", + create_root=True, + node="gateway_item", + key="name", + ) self.name = "pfsense_gateway" self.interface_elt = None self.dynamic = False @@ -50,148 +60,214 @@ def __init__(self, module, pfsense=None): # params processing # def _check_gateway_groups(self): - """ check if gateway is in use in gateway groups """ + """check if gateway is in use in gateway groups""" for elt in self.root_elt: - if (elt.tag == 'defaultgw4' or elt.tag == 'defaultgw6') and (elt.text is not None and elt.text == self.params['name']): + if (elt.tag == "defaultgw4" or elt.tag == "defaultgw6") and ( + elt.text is not None and elt.text == self.params["name"] + ): return False - if elt.tag != 'gateway_group': + if elt.tag != "gateway_group": continue - items = elt.findall('.//item') + items = elt.findall(".//item") for item in items: - fields = item.text.split('|') - if fields and fields[0] == self.params['name']: + fields = item.text.split("|") + if fields and fields[0] == self.params["name"]: return False return True def _check_routes(self): - """ check if gateway is in use in static routes """ - routes = self.pfsense.get_element('staticroutes') + """check if gateway is in use in static routes""" + routes = self.pfsense.get_element("staticroutes") if routes is None: return True for elt in routes: - if elt.find('gateway').text == self.params['name']: + if elt.find("gateway").text == self.params["name"]: return False return True def _check_subnet(self): - """ check if addr lies into interface subnets """ + """check if addr lies into interface subnets""" + def _check_vips(): - virtualips = self.pfsense.get_element('virtualip') + virtualips = self.pfsense.get_element("virtualip") if virtualips is None: return False for vip_elt in virtualips: - if vip_elt.find('interface').text != self.interface_elt.tag or vip_elt.find('mode').text != 'other' or vip_elt.find('type').text != 'network': + if ( + vip_elt.find("interface").text != self.interface_elt.tag + or vip_elt.find("mode").text != "other" + or vip_elt.find("type").text != "network" + ): continue - subnet = ip_network(u'{0}/{1}'.format(vip_elt.find('subnet').text, vip_elt.find('subnet_bits').text), strict=False) + subnet = ip_network( + "{0}/{1}".format( + vip_elt.find("subnet").text, vip_elt.find("subnet_bits").text + ), + strict=False, + ) if addr in subnet: return True return False - if self.params['ipprotocol'] == 'inet': - inet_type = 'IPv4' - f1_elt = self.interface_elt.find('ipaddr') - f2_elt = self.interface_elt.find('subnet') + if self.params["ipprotocol"] == "inet": + inet_type = "IPv4" + f1_elt = self.interface_elt.find("ipaddr") + f2_elt = self.interface_elt.find("subnet") else: - inet_type = 'IPv6' - f1_elt = self.interface_elt.find('ipaddrv6') - f2_elt = self.interface_elt.find('subnetv6') - if f1_elt is None or f1_elt.text is None or f2_elt is None or f2_elt.text is None: - self.module.fail_json(msg='Cannot add {0} Gateway Address because no {0} address could be found on the interface.'.format(inet_type)) + inet_type = "IPv6" + f1_elt = self.interface_elt.find("ipaddrv6") + f2_elt = self.interface_elt.find("subnetv6") + if ( + f1_elt is None + or f1_elt.text is None + or f2_elt is None + or f2_elt.text is None + ): + self.module.fail_json( + msg="Cannot add {0} Gateway Address because no {0} address could be found on the interface.".format( + inet_type + ) + ) try: - if self.params['nonlocalgateway']: + if self.params["nonlocalgateway"]: return - addr = ip_address(u'{0}'.format(self.params['gateway'])) - subnet = ip_network(u'{0}/{1}'.format(f1_elt.text, f2_elt.text), strict=False) + addr = ip_address("{0}".format(self.params["gateway"])) + subnet = ip_network( + "{0}/{1}".format(f1_elt.text, f2_elt.text), strict=False + ) if addr in subnet or _check_vips(): return - self.module.fail_json(msg="The gateway address {0} does not lie within one of the chosen interface's subnets.".format(self.params['gateway'])) + self.module.fail_json( + msg="The gateway address {0} does not lie within one of the chosen interface's subnets.".format( + self.params["gateway"] + ) + ) except ValueError: - self.module.fail_json(msg='Cannot add {0} Gateway Address because no {0} address could be found on the interface.'.format(inet_type)) + self.module.fail_json( + msg="Cannot add {0} Gateway Address because no {0} address could be found on the interface.".format( + inet_type + ) + ) def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() - obj['name'] = params['name'] - if params['state'] == 'present': - obj['interface'] = self.pfsense.parse_interface(params['interface']) - self.interface_elt = self.pfsense.get_interface_elt(obj['interface']) - self._get_ansible_param(obj, 'ipprotocol') - self._get_ansible_param(obj, 'gateway') - self._get_ansible_param(obj, 'descr') - self._get_ansible_param(obj, 'monitor') - self._get_ansible_param(obj, 'weight') - - self._get_ansible_param_bool(obj, 'disabled', value=None) - self._get_ansible_param_bool(obj, 'monitor_disable', value=None) - self._get_ansible_param_bool(obj, 'action_disable', value=None) - self._get_ansible_param_bool(obj, 'force_down', value=None) - self._get_ansible_param_bool(obj, 'nonlocalgateway', value=None) + obj["name"] = params["name"] + if params["state"] == "present": + obj["interface"] = self.pfsense.parse_interface(params["interface"]) + self.interface_elt = self.pfsense.get_interface_elt(obj["interface"]) + self._get_ansible_param(obj, "ipprotocol") + self._get_ansible_param(obj, "gateway") + self._get_ansible_param(obj, "descr") + self._get_ansible_param(obj, "monitor") + self._get_ansible_param(obj, "weight") + + self._get_ansible_param_bool(obj, "disabled", value=None) + self._get_ansible_param_bool(obj, "monitor_disable", value=None) + self._get_ansible_param_bool(obj, "action_disable", value=None) + self._get_ansible_param_bool(obj, "force_down", value=None) + self._get_ansible_param_bool(obj, "nonlocalgateway", value=None) if not self.dynamic: self._check_subnet() - elif self.target_elt.find('interface').text != obj['interface']: - self.module.fail_json(msg="The gateway use 'dynamic' as a target. You can not change the interface") - elif self.target_elt.find('ipprotocol').text != params['ipprotocol']: - self.module.fail_json(msg="The gateway use 'dynamic' as a target. You can not change ipprotocol") + elif self.target_elt.find("interface").text != obj["interface"]: + self.module.fail_json( + msg="The gateway use 'dynamic' as a target. You can not change the interface" + ) + elif self.target_elt.find("ipprotocol").text != params["ipprotocol"]: + self.module.fail_json( + msg="The gateway use 'dynamic' as a target. You can not change ipprotocol" + ) return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - self.target_elt = self.pfsense.find_gateway_elt(params['name'], dhcp=True, vti=True) - if self.target_elt is not None and self.target_elt.find('gateway').text == 'dynamic': + self.target_elt = self.pfsense.find_gateway_elt( + params["name"], dhcp=True, vti=True + ) + if ( + self.target_elt is not None + and self.target_elt.find("gateway").text == "dynamic" + ): self.dynamic = True - if params['state'] == 'present': + if params["state"] == "present": # check weight - if params.get('weight') is not None and (params['weight'] < 1 or params['weight'] > 30): - self.module.fail_json(msg='weight must be between 1 and 30') + if params.get("weight") is not None and ( + params["weight"] < 1 or params["weight"] > 30 + ): + self.module.fail_json(msg="weight must be between 1 and 30") if self.dynamic: - if params['gateway'] != 'dynamic': - self.module.fail_json(msg="The gateway use 'dynamic' as a target. This is read-only, so you must set gateway as dynamic too") + if params["gateway"] != "dynamic": + self.module.fail_json( + msg="The gateway use 'dynamic' as a target. This is read-only, so you must set gateway as dynamic too" + ) else: - self.pfsense.check_ip_address(params['gateway'], params['ipprotocol'], 'gateway', fail_ifnotip=True) - if params.get('monitor') is not None and params['monitor'] != '': - self.pfsense.check_ip_address(params['monitor'], params['ipprotocol'], 'monitor', fail_ifnotip=True) - - self.pfsense.check_name(params['name'], 'gateway') + self.pfsense.check_ip_address( + params["gateway"], + params["ipprotocol"], + "gateway", + fail_ifnotip=True, + ) + if params.get("monitor") is not None and params["monitor"] != "": + self.pfsense.check_ip_address( + params["monitor"], + params["ipprotocol"], + "monitor", + fail_ifnotip=True, + ) + + self.pfsense.check_name(params["name"], "gateway") else: if self.dynamic: - self.module.fail_json(msg="The gateway use 'dynamic' as a target. You can not delete it") + self.module.fail_json( + msg="The gateway use 'dynamic' as a target. You can not delete it" + ) if not self._check_gateway_groups() or not self._check_routes(): - self.module.fail_json(msg="The gateway is still in use. You can not delete it") + self.module.fail_json( + msg="The gateway is still in use. You can not delete it" + ) ############################## # XML processing # @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - return ['disabled', 'monitor', 'monitor_disable', 'action_disable', 'force_down', 'nonlocalgateway'] + """returns the list of params to remove if they are not set""" + return [ + "disabled", + "monitor", + "monitor_disable", + "action_disable", + "force_down", + "nonlocalgateway", + ] ############################## # run # def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell(''' + """make the target pfsense reload""" + return self.pfsense.phpshell( + """ require_once("filter.inc"); $retval = 0; @@ -204,44 +280,99 @@ def _update(self): send_event("service reload dyndnsall"); if ($retval == 0) clear_subsystem_dirty('staticroutes'); -''') +""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.obj['name']) + """return obj's name""" + return "'{0}'".format(self.obj["name"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'interface') - values += self.format_cli_field(self.obj, 'ipprotocol', default='inet') - values += self.format_cli_field(self.obj, 'gateway') - values += self.format_cli_field(self.obj, 'descr', default='') - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.obj, 'monitor') - values += self.format_cli_field(self.params, 'monitor_disable', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'action_disable', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'force_down', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.obj, 'weight', default='1') - values += self.format_cli_field(self.params, 'nonlocalgateway', fvalue=self.fvalue_bool, default=False) + values += self.format_cli_field(self.params, "interface") + values += self.format_cli_field(self.obj, "ipprotocol", default="inet") + values += self.format_cli_field(self.obj, "gateway") + values += self.format_cli_field(self.obj, "descr", default="") + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.obj, "monitor") + values += self.format_cli_field( + self.params, "monitor_disable", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field( + self.params, "action_disable", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field( + self.params, "force_down", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.obj, "weight", default="1") + values += self.format_cli_field( + self.params, "nonlocalgateway", fvalue=self.fvalue_bool, default=False + ) else: fbefore = dict() - fbefore['interface'] = self.pfsense.get_interface_display_name(before['interface']) - - values += self.format_updated_cli_field(self.params, fbefore, 'interface', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'ipprotocol', default='inet', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'gateway', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'descr', default='', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'disabled', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'monitor', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'monitor_disable', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'action_disable', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'force_down', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'weight', default='1', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'nonlocalgateway', fvalue=self.fvalue_bool) + fbefore["interface"] = self.pfsense.get_interface_display_name( + before["interface"] + ) + + values += self.format_updated_cli_field( + self.params, fbefore, "interface", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "ipprotocol", default="inet", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "gateway", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "descr", default="", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "disabled", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, before, "monitor", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "monitor_disable", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "action_disable", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "force_down", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, before, "weight", default="1", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "nonlocalgateway", fvalue=self.fvalue_bool + ) return values diff --git a/plugins/module_utils/haproxy_backend.py b/plugins/module_utils/haproxy_backend.py index 02a2d475..154c5153 100644 --- a/plugins/module_utils/haproxy_backend.py +++ b/plugins/module_utils/haproxy_backend.py @@ -4,37 +4,61 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) HAPROXY_BACKEND_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - name=dict(required=True, type='str'), - balance=dict(default='none', choices=['none', 'roundrobin', 'static-rr', 'leastconn', 'source', 'uri']), - balance_urilen=dict(required=False, type='int'), - balance_uridepth=dict(required=False, type='int'), - balance_uriwhole=dict(required=False, type='bool'), - connection_timeout=dict(required=False, type='int'), - server_timeout=dict(required=False, type='int'), - check_type=dict(default='none', choices=['none', 'Basic', 'HTTP', 'Agent', 'LDAP', 'MySQL', 'PostgreSQL', 'Redis', 'SMTP', 'ESMTP', 'SSL']), - check_frequency=dict(required=False, type='int'), - retries=dict(required=False, type='int'), - log_checks=dict(required=False, type='bool'), - httpcheck_method=dict(required=False, choices=['OPTIONS', 'HEAD', 'GET', 'POST', 'PUT', 'DELETE', 'TRACE']), - monitor_uri=dict(required=False, type='str'), - monitor_httpversion=dict(required=False, type='str'), - monitor_username=dict(required=False, type='str'), - monitor_domain=dict(required=False, type='str'), + state=dict(default="present", choices=["present", "absent"]), + name=dict(required=True, type="str"), + balance=dict( + default="none", + choices=["none", "roundrobin", "static-rr", "leastconn", "source", "uri"], + ), + balance_urilen=dict(required=False, type="int"), + balance_uridepth=dict(required=False, type="int"), + balance_uriwhole=dict(required=False, type="bool"), + connection_timeout=dict(required=False, type="int"), + server_timeout=dict(required=False, type="int"), + check_type=dict( + default="none", + choices=[ + "none", + "Basic", + "HTTP", + "Agent", + "LDAP", + "MySQL", + "PostgreSQL", + "Redis", + "SMTP", + "ESMTP", + "SSL", + ], + ), + check_frequency=dict(required=False, type="int"), + retries=dict(required=False, type="int"), + log_checks=dict(required=False, type="bool"), + httpcheck_method=dict( + required=False, + choices=["OPTIONS", "HEAD", "GET", "POST", "PUT", "DELETE", "TRACE"], + ), + monitor_uri=dict(required=False, type="str"), + monitor_httpversion=dict(required=False, type="str"), + monitor_username=dict(required=False, type="str"), + monitor_domain=dict(required=False, type="str"), ) class PFSenseHaproxyBackendModule(PFSenseModuleBase): - """ module managing pfsense haproxy backends """ + """module managing pfsense haproxy backends""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return HAPROXY_BACKEND_ARGUMENT_SPEC ############################## @@ -45,69 +69,77 @@ def __init__(self, module, pfsense=None): self.name = "pfsense_haproxy_backend" self.obj = dict() - pkgs_elt = self.pfsense.get_element('installedpackages') - self.haproxy = pkgs_elt.find('haproxy') if pkgs_elt is not None else None - self.root_elt = self.haproxy.find('ha_pools') if self.haproxy is not None else None + pkgs_elt = self.pfsense.get_element("installedpackages") + self.haproxy = pkgs_elt.find("haproxy") if pkgs_elt is not None else None + self.root_elt = ( + self.haproxy.find("ha_pools") if self.haproxy is not None else None + ) if self.root_elt is None: - self.module.fail_json(msg='Unable to find backends XML configuration entry. Are you sure haproxy is installed ?') + self.module.fail_json( + msg="Unable to find backends XML configuration entry. Are you sure haproxy is installed ?" + ) ############################## # params processing # def _params_to_obj(self): - """ return a backend dict from module params """ + """return a backend dict from module params""" obj = dict() - obj['name'] = self.params['name'] - if self.params['state'] == 'present': - self._get_ansible_param(obj, 'balance', force=True) - if obj['balance'] == 'none': - obj['balance'] = None - self._get_ansible_param(obj, 'balance_urilen', force=True) - self._get_ansible_param(obj, 'balance_uridepth', force=True) - self._get_ansible_param(obj, 'connection_timeout', force=True) - self._get_ansible_param(obj, 'server_timeout', force=True) - self._get_ansible_param(obj, 'check_type', force=True) - self._get_ansible_param(obj, 'check_frequency', fname='checkinter', force=True) - self._get_ansible_param(obj, 'retries', force=True) - self._get_ansible_param_bool(obj, 'log_checks', fname='log-health-checks', force=True) - self._get_ansible_param_bool(obj, 'balance_uriwhole', force=True) - self._get_ansible_param(obj, 'httpcheck_method', force=True) - self._get_ansible_param(obj, 'monitor_uri', force=True) - self._get_ansible_param(obj, 'monitor_httpversion', force=True) - self._get_ansible_param(obj, 'monitor_username', force=True) - self._get_ansible_param(obj, 'monitor_domain', force=True) + obj["name"] = self.params["name"] + if self.params["state"] == "present": + self._get_ansible_param(obj, "balance", force=True) + if obj["balance"] == "none": + obj["balance"] = None + self._get_ansible_param(obj, "balance_urilen", force=True) + self._get_ansible_param(obj, "balance_uridepth", force=True) + self._get_ansible_param(obj, "connection_timeout", force=True) + self._get_ansible_param(obj, "server_timeout", force=True) + self._get_ansible_param(obj, "check_type", force=True) + self._get_ansible_param( + obj, "check_frequency", fname="checkinter", force=True + ) + self._get_ansible_param(obj, "retries", force=True) + self._get_ansible_param_bool( + obj, "log_checks", fname="log-health-checks", force=True + ) + self._get_ansible_param_bool(obj, "balance_uriwhole", force=True) + self._get_ansible_param(obj, "httpcheck_method", force=True) + self._get_ansible_param(obj, "monitor_uri", force=True) + self._get_ansible_param(obj, "monitor_httpversion", force=True) + self._get_ansible_param(obj, "monitor_username", force=True) + self._get_ansible_param(obj, "monitor_domain", force=True) return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" # check name - if re.search(r'[^a-zA-Z0-9\.\-_]', self.params['name']) is not None: + if re.search(r"[^a-zA-Z0-9\.\-_]", self.params["name"]) is not None: self.module.fail_json(msg="The field 'name' contains invalid characters.") ############################## # XML processing # def _create_target(self): - """ create the XML target_elt """ - server_elt = self.pfsense.new_element('item') - self.obj['id'] = self._get_next_id() + """create the XML target_elt""" + server_elt = self.pfsense.new_element("item") + self.obj["id"] = self._get_next_id() return server_elt def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" for item_elt in self.root_elt: - if item_elt.tag != 'item': + if item_elt.tag != "item": continue - name_elt = item_elt.find('name') - if name_elt is not None and name_elt.text == self.obj['name']: + name_elt = item_elt.find("name") + if name_elt is not None and name_elt.text == self.obj["name"]: return item_elt return None def _get_next_id(self): - """ get next free haproxy id """ + """get next free haproxy id""" max_id = 99 - id_elts = self.haproxy.findall('.//id') + id_elts = self.haproxy.findall(".//id") for id_elt in id_elts: if id_elt.text is None: continue @@ -120,53 +152,102 @@ def _get_next_id(self): # run # def _update(self): - """ make the target pfsense reload haproxy """ - return self.pfsense.phpshell('''require_once("haproxy/haproxy.inc"); -$result = haproxy_check_and_run($savemsg, true); if ($result) unlink_if_exists($d_haproxyconfdirty_path);''') + """make the target pfsense reload haproxy""" + return self.pfsense.phpshell( + """require_once("haproxy/haproxy.inc"); +$result = haproxy_check_and_run($savemsg, true); if ($result) unlink_if_exists($d_haproxyconfdirty_path);""" + ) ############################## # Logging # def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'balance') - values += self.format_cli_field(self.params, 'balance_urilen') - values += self.format_cli_field(self.params, 'balance_uridepth') - values += self.format_cli_field(self.params, 'balance_uriwhole', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'connection_timeout') - values += self.format_cli_field(self.params, 'server_timeout') - values += self.format_cli_field(self.params, 'check_type') - values += self.format_cli_field(self.params, 'check_frequency') - values += self.format_cli_field(self.params, 'retries') - values += self.format_cli_field(self.params, 'log_checks', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'httpcheck_method') - values += self.format_cli_field(self.params, 'monitor_uri') - values += self.format_cli_field(self.params, 'monitor_httpversion') - values += self.format_cli_field(self.params, 'monitor_username') - values += self.format_cli_field(self.params, 'monitor_domain') + values += self.format_cli_field(self.params, "balance") + values += self.format_cli_field(self.params, "balance_urilen") + values += self.format_cli_field(self.params, "balance_uridepth") + values += self.format_cli_field( + self.params, "balance_uriwhole", fvalue=self.fvalue_bool + ) + values += self.format_cli_field(self.params, "connection_timeout") + values += self.format_cli_field(self.params, "server_timeout") + values += self.format_cli_field(self.params, "check_type") + values += self.format_cli_field(self.params, "check_frequency") + values += self.format_cli_field(self.params, "retries") + values += self.format_cli_field( + self.params, "log_checks", fvalue=self.fvalue_bool + ) + values += self.format_cli_field(self.params, "httpcheck_method") + values += self.format_cli_field(self.params, "monitor_uri") + values += self.format_cli_field(self.params, "monitor_httpversion") + values += self.format_cli_field(self.params, "monitor_username") + values += self.format_cli_field(self.params, "monitor_domain") else: - for param in ['balance', 'log-health-checks', 'balance_uriwhole']: - if param in before and before[param] == '': + for param in ["balance", "log-health-checks", "balance_uriwhole"]: + if param in before and before[param] == "": before[param] = None - values += self.format_updated_cli_field(self.obj, before, 'balance', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'balance_urilen', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'balance_uridepth', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'balance_uriwhole', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'connection_timeout', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'server_timeout', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'check_type', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'checkinter', add_comma=(values), fname='check_frequency') - values += self.format_updated_cli_field(self.obj, before, 'retries', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'log-health-checks', add_comma=(values), fname='log_checks', fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'httpcheck_method', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'monitor_uri', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'monitor_httpversion', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'monitor_username', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'monitor_domain', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "balance", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "balance_urilen", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "balance_uridepth", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "balance_uriwhole", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.obj, before, "connection_timeout", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "server_timeout", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "check_type", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "checkinter", + add_comma=(values), + fname="check_frequency", + ) + values += self.format_updated_cli_field( + self.obj, before, "retries", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "log-health-checks", + add_comma=(values), + fname="log_checks", + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.obj, before, "httpcheck_method", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "monitor_uri", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "monitor_httpversion", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "monitor_username", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "monitor_domain", add_comma=(values) + ) return values def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.obj['name']) + """return obj's name""" + return "'{0}'".format(self.obj["name"]) diff --git a/plugins/module_utils/haproxy_backend_server.py b/plugins/module_utils/haproxy_backend_server.py index e6c9979b..ab05d398 100644 --- a/plugins/module_utils/haproxy_backend_server.py +++ b/plugins/module_utils/haproxy_backend_server.py @@ -4,44 +4,47 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) HAPROXY_BACKEND_SERVER_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - backend=dict(required=True, type='str'), - name=dict(required=True, type='str'), - mode=dict(default='active', choices=['active', 'backup', 'disabled', 'inactive']), - forwardto=dict(required=False, type='str'), - address=dict(required=False, type='str'), - port=dict(required=False, type='int'), - ssl=dict(required=False, type='bool'), - checkssl=dict(required=False, type='bool'), - weight=dict(required=False, type='int'), - sslserververify=dict(required=False, type='bool'), - verifyhost=dict(required=False, type='str'), - ca=dict(required=False, type='str'), - crl=dict(required=False, type='str'), - clientcert=dict(required=False, type='str'), - cookie=dict(required=False, type='str'), - maxconn=dict(required=False, type='int'), - advanced=dict(required=False, type='str'), - istemplate=dict(required=False, type='str'), + state=dict(default="present", choices=["present", "absent"]), + backend=dict(required=True, type="str"), + name=dict(required=True, type="str"), + mode=dict(default="active", choices=["active", "backup", "disabled", "inactive"]), + forwardto=dict(required=False, type="str"), + address=dict(required=False, type="str"), + port=dict(required=False, type="int"), + ssl=dict(required=False, type="bool"), + checkssl=dict(required=False, type="bool"), + weight=dict(required=False, type="int"), + sslserververify=dict(required=False, type="bool"), + verifyhost=dict(required=False, type="str"), + ca=dict(required=False, type="str"), + crl=dict(required=False, type="str"), + clientcert=dict(required=False, type="str"), + cookie=dict(required=False, type="str"), + maxconn=dict(required=False, type="int"), + advanced=dict(required=False, type="str"), + istemplate=dict(required=False, type="str"), ) HAPROXY_BACKEND_SERVER_MUTUALLY_EXCLUSIVE = [ - ['forwardto', 'address'], - ['forwardto', 'port'], + ["forwardto", "address"], + ["forwardto", "port"], ] class PFSenseHaproxyBackendServerModule(PFSenseModuleBase): - """ module managing pfsense haproxy backend servers """ + """module managing pfsense haproxy backend servers""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return HAPROXY_BACKEND_SERVER_ARGUMENT_SPEC ############################## @@ -53,11 +56,15 @@ def __init__(self, module, pfsense=None): self.root_elt = None self.obj = dict() - pkgs_elt = self.pfsense.get_element('installedpackages') - self.haproxy = pkgs_elt.find('haproxy') if pkgs_elt is not None else None - self.backends = self.haproxy.find('ha_pools') if self.haproxy is not None else None + pkgs_elt = self.pfsense.get_element("installedpackages") + self.haproxy = pkgs_elt.find("haproxy") if pkgs_elt is not None else None + self.backends = ( + self.haproxy.find("ha_pools") if self.haproxy is not None else None + ) if self.backends is None: - self.module.fail_json(msg='Unable to find backends XML configuration entry. Are you sure haproxy is installed ?') + self.module.fail_json( + msg="Unable to find backends XML configuration entry. Are you sure haproxy is installed ?" + ) self.backend = None self.servers = None @@ -66,120 +73,154 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() - obj['name'] = params['name'] - if params['state'] == 'present': - obj['status'] = params['mode'] + obj["name"] = params["name"] + if params["state"] == "present": + obj["status"] = params["mode"] - for param in ['ssl', 'checkssl', 'sslserververify']: + for param in ["ssl", "checkssl", "sslserververify"]: self._get_ansible_param_bool(obj, param) - self._get_ansible_param(obj, 'forwardto') - self._get_ansible_param(obj, 'address') - self._get_ansible_param(obj, 'port') - self._get_ansible_param(obj, 'weight') - self._get_ansible_param(obj, 'verifyhost') + self._get_ansible_param(obj, "forwardto") + self._get_ansible_param(obj, "address") + self._get_ansible_param(obj, "port") + self._get_ansible_param(obj, "weight") + self._get_ansible_param(obj, "verifyhost") - if 'ca' in params and params['ca'] is not None and params['ca'] != '': - ca_elt = self.pfsense.find_ca_elt(params['ca']) + if "ca" in params and params["ca"] is not None and params["ca"] != "": + ca_elt = self.pfsense.find_ca_elt(params["ca"]) if ca_elt is None: - self.module.fail_json(msg='%s is not a valid certificate authority' % (params['ca'])) - obj['ssl-server-ca'] = ca_elt.find('refid').text + self.module.fail_json( + msg="%s is not a valid certificate authority" % (params["ca"]) + ) + obj["ssl-server-ca"] = ca_elt.find("refid").text - if 'crl' in params and params['crl'] is not None and params['crl'] != '': - crl_elt = self.pfsense.find_crl_elt(params['crl']) + if "crl" in params and params["crl"] is not None and params["crl"] != "": + crl_elt = self.pfsense.find_crl_elt(params["crl"]) if crl_elt is None: - self.module.fail_json(msg='%s is not a valid certificate revocation list' % (params['crl'])) - obj['ssl-server-crl'] = crl_elt.find('refid').text - - if 'clientcert' in params and params['clientcert'] is not None and params['clientcert'] != '': - cert = self.pfsense.find_cert_elt(params['clientcert']) + self.module.fail_json( + msg="%s is not a valid certificate revocation list" + % (params["crl"]) + ) + obj["ssl-server-crl"] = crl_elt.find("refid").text + + if ( + "clientcert" in params + and params["clientcert"] is not None + and params["clientcert"] != "" + ): + cert = self.pfsense.find_cert_elt(params["clientcert"]) if cert is None: - self.module.fail_json(msg='%s is not a valid certificate' % (params['clientcert'])) - obj['ssl-server-clientcert'] = cert.find('refid').text + self.module.fail_json( + msg="%s is not a valid certificate" % (params["clientcert"]) + ) + obj["ssl-server-clientcert"] = cert.find("refid").text - self._get_ansible_param(obj, 'cookie') - self._get_ansible_param(obj, 'maxconn') - self._get_ansible_param(obj, 'advanced') - self._get_ansible_param(obj, 'istemplate') + self._get_ansible_param(obj, "cookie") + self._get_ansible_param(obj, "maxconn") + self._get_ansible_param(obj, "advanced") + self._get_ansible_param(obj, "istemplate") return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - if re.search(r'[^a-zA-Z0-9\.\-_]', params['name']) is not None: + if re.search(r"[^a-zA-Z0-9\.\-_]", params["name"]) is not None: self.module.fail_json(msg="The field 'name' contains invalid characters") - if len(params['name']) < 2: + if len(params["name"]) < 2: self.module.fail_json(msg="The field 'name' must be at least 2 characters") - self.backend = self._find_backend(params['backend']) + self.backend = self._find_backend(params["backend"]) if self.backend is None: - self.module.fail_json(msg="The backend named '{0}' does not exist".format(params['backend'])) + self.module.fail_json( + msg="The backend named '{0}' does not exist".format(params["backend"]) + ) - self.root_elt = self.backend.find('ha_servers') + self.root_elt = self.backend.find("ha_servers") if self.root_elt is None: - self.root_elt = self.pfsense.new_element('ha_servers') + self.root_elt = self.pfsense.new_element("ha_servers") self.backend.append(self.root_elt) - if 'forwardto' in params and params['forwardto'] is not None: + if "forwardto" in params and params["forwardto"] is not None: frontend_elt = None - frontends = self.haproxy.find('ha_backends') + frontends = self.haproxy.find("ha_backends") for item_elt in frontends: - if item_elt.tag != 'item': + if item_elt.tag != "item": continue - name_elt = item_elt.find('name') - if name_elt is not None and name_elt.text == params['forwardto']: + name_elt = item_elt.find("name") + if name_elt is not None and name_elt.text == params["forwardto"]: frontend_elt = item_elt break if frontend_elt is None: - self.module.fail_json(msg="The frontend named '{0}' does not exist".format(params['forwardto'])) + self.module.fail_json( + msg="The frontend named '{0}' does not exist".format( + params["forwardto"] + ) + ) ############################## # XML processing # def _create_target(self): - """ create the XML target_elt """ - server_elt = self.pfsense.new_element('item') - self.obj['id'] = self._get_next_id() + """create the XML target_elt""" + server_elt = self.pfsense.new_element("item") + self.obj["id"] = self._get_next_id() return server_elt def _find_backend(self, name): - """ return the target backend_elt if found """ + """return the target backend_elt if found""" for item_elt in self.backends: - if item_elt.tag != 'item': + if item_elt.tag != "item": continue - name_elt = item_elt.find('name') + name_elt = item_elt.find("name") if name_elt is not None and name_elt.text == name: return item_elt return None def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" for item_elt in self.root_elt: - if item_elt.tag != 'item': + if item_elt.tag != "item": continue - name_elt = item_elt.find('name') - if name_elt is not None and name_elt.text == self.obj['name']: + name_elt = item_elt.find("name") + if name_elt is not None and name_elt.text == self.obj["name"]: return item_elt return None @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - params = ['ssl', 'checkssl', 'sslserververify', 'forwardto', 'address', 'port', 'weight', 'istemplate', 'verifyhost'] - params += ['ssl-server-crl', 'ssl-server-ca', 'ssl-server-clientcert', 'cookie', 'maxconn', 'advanced'] + """returns the list of params to remove if they are not set""" + params = [ + "ssl", + "checkssl", + "sslserververify", + "forwardto", + "address", + "port", + "weight", + "istemplate", + "verifyhost", + ] + params += [ + "ssl-server-crl", + "ssl-server-ca", + "ssl-server-clientcert", + "cookie", + "maxconn", + "advanced", + ] return params def _get_next_id(self): - """ get next free haproxy id """ + """get next free haproxy id""" max_id = 99 - id_elts = self.haproxy.findall('.//id') + id_elts = self.haproxy.findall(".//id") for id_elt in id_elts: if id_elt.text is None: continue @@ -192,78 +233,136 @@ def _get_next_id(self): # run # def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell('''require_once("haproxy/haproxy.inc"); -$result = haproxy_check_and_run($savemsg, true); if ($result) unlink_if_exists($d_haproxyconfdirty_path);''') + """make the target pfsense reload""" + return self.pfsense.phpshell( + """require_once("haproxy/haproxy.inc"); +$result = haproxy_check_and_run($savemsg, true); if ($result) unlink_if_exists($d_haproxyconfdirty_path);""" + ) ############################## # Logging # def _get_ref_names(self, before): - """ get cert and ca names """ - if 'ssl-server-ca' in before and before['ssl-server-ca'] is not None and before['ssl-server-ca'] != '': - elt = self.pfsense.find_ca_elt(before['ssl-server-ca'], 'refid') + """get cert and ca names""" + if ( + "ssl-server-ca" in before + and before["ssl-server-ca"] is not None + and before["ssl-server-ca"] != "" + ): + elt = self.pfsense.find_ca_elt(before["ssl-server-ca"], "refid") if elt is not None: - before['ca'] = elt.find('descr').text - if 'ca' not in before: - before['ca'] = None - - if 'ssl-server-crl' in before and before['ssl-server-crl'] is not None and before['ssl-server-crl'] != '': - elt = self.pfsense.find_crl_elt(before['ssl-server-crl'], 'refid') + before["ca"] = elt.find("descr").text + if "ca" not in before: + before["ca"] = None + + if ( + "ssl-server-crl" in before + and before["ssl-server-crl"] is not None + and before["ssl-server-crl"] != "" + ): + elt = self.pfsense.find_crl_elt(before["ssl-server-crl"], "refid") if elt is not None: - before['crl'] = elt.find('descr').text - if 'crl' not in before: - before['crl'] = None - - if 'ssl-server-clientcert' in before and before['ssl-server-clientcert'] is not None and before['ssl-server-clientcert'] != '': - elt = self.pfsense.find_cert_elt(before['ssl-server-clientcert'], 'refid') + before["crl"] = elt.find("descr").text + if "crl" not in before: + before["crl"] = None + + if ( + "ssl-server-clientcert" in before + and before["ssl-server-clientcert"] is not None + and before["ssl-server-clientcert"] != "" + ): + elt = self.pfsense.find_cert_elt(before["ssl-server-clientcert"], "refid") if elt is not None: - before['clientcert'] = elt.find('descr').text - if 'clientcert' not in before: - before['clientcert'] = None + before["clientcert"] = elt.find("descr").text + if "clientcert" not in before: + before["clientcert"] = None def _get_obj_name(self): - """ return obj's name """ - return "'{0}' on '{1}'".format(self.obj['name'], self.params['backend']) + """return obj's name""" + return "'{0}' on '{1}'".format(self.obj["name"], self.params["backend"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'mode', fname='status') - values += self.format_cli_field(self.params, 'forwardto') - values += self.format_cli_field(self.params, 'address') - values += self.format_cli_field(self.params, 'port') - values += self.format_cli_field(self.params, 'ssl', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'checkssl', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'weight') - values += self.format_cli_field(self.params, 'sslserververify', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'ca') - values += self.format_cli_field(self.params, 'crl') - values += self.format_cli_field(self.params, 'clientcert') - values += self.format_cli_field(self.params, 'cookie') - values += self.format_cli_field(self.params, 'maxconn') - values += self.format_cli_field(self.params, 'advanced') - values += self.format_cli_field(self.params, 'istemplate') + values += self.format_cli_field(self.params, "mode", fname="status") + values += self.format_cli_field(self.params, "forwardto") + values += self.format_cli_field(self.params, "address") + values += self.format_cli_field(self.params, "port") + values += self.format_cli_field(self.params, "ssl", fvalue=self.fvalue_bool) + values += self.format_cli_field( + self.params, "checkssl", fvalue=self.fvalue_bool + ) + values += self.format_cli_field(self.params, "weight") + values += self.format_cli_field( + self.params, "sslserververify", fvalue=self.fvalue_bool + ) + values += self.format_cli_field(self.params, "ca") + values += self.format_cli_field(self.params, "crl") + values += self.format_cli_field(self.params, "clientcert") + values += self.format_cli_field(self.params, "cookie") + values += self.format_cli_field(self.params, "maxconn") + values += self.format_cli_field(self.params, "advanced") + values += self.format_cli_field(self.params, "istemplate") else: - for param in ['ssl', 'checkssl', 'sslserververify']: - if param in before and before[param] == '': + for param in ["ssl", "checkssl", "sslserververify"]: + if param in before and before[param] == "": before[param] = None self._get_ref_names(before) - values += self.format_updated_cli_field(self.obj, before, 'status', add_comma=(values), fname='mode') - values += self.format_updated_cli_field(self.obj, before, 'forwardto', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'address', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'port', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'ssl', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'checkssl', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'weight', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'sslserververify', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'verifyhost', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'ca', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'crl', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'clientcert', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'cookie', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'maxconn', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'advanced', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'istemplate', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "status", add_comma=(values), fname="mode" + ) + values += self.format_updated_cli_field( + self.obj, before, "forwardto", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "address", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "port", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "ssl", add_comma=(values), fvalue=self.fvalue_bool + ) + values += self.format_updated_cli_field( + self.obj, + before, + "checkssl", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.obj, before, "weight", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "sslserververify", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.obj, before, "verifyhost", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "ca", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "crl", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "clientcert", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "cookie", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "maxconn", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "advanced", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "istemplate", add_comma=(values) + ) return values diff --git a/plugins/module_utils/interface.py b/plugins/module_utils/interface.py index cf8e7f26..28bb50ad 100644 --- a/plugins/module_utils/interface.py +++ b/plugins/module_utils/interface.py @@ -5,37 +5,45 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase -from ansible_collections.pfsensible.core.plugins.module_utils.rule import PFSenseRuleModule +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule import ( + PFSenseRuleModule, +) + try: from ipaddress import ip_network except ImportError: - from ansible_collections.community.general.plugins.module_utils.compat.ipaddress import ip_network + from ansible_collections.community.general.plugins.module_utils.compat.ipaddress import ( + ip_network, + ) INTERFACE_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - descr=dict(required=True, type='str'), - interface=dict(required=False, type='str'), - interface_descr=dict(required=False, type='str'), - enable=dict(default=False, type='bool'), - ipv4_type=dict(default='none', choices=['none', 'static', 'dhcp']), - ipv6_type=dict(default='none', choices=['none', 'static', 'slaac']), - mac=dict(required=False, type='str'), - mtu=dict(required=False, type='int'), - mss=dict(required=False, type='int'), - speed_duplex=dict(default='autoselect', required=False, type='str'), - ipv4_address=dict(required=False, type='str'), - ipv4_prefixlen=dict(default=24, required=False, type='int'), - ipv4_gateway=dict(required=False, type='str'), - ipv6_address=dict(required=False, type='str'), - ipv6_prefixlen=dict(default=128, required=False, type='int'), - ipv6_gateway=dict(required=False, type='str'), - blockpriv=dict(required=False, type='bool'), - blockbogons=dict(required=False, type='bool'), - slaacusev4iface=dict(required=False, type='bool'), + state=dict(default="present", choices=["present", "absent"]), + descr=dict(required=True, type="str"), + interface=dict(required=False, type="str"), + interface_descr=dict(required=False, type="str"), + enable=dict(default=False, type="bool"), + ipv4_type=dict(default="none", choices=["none", "static", "dhcp"]), + ipv6_type=dict(default="none", choices=["none", "static", "slaac"]), + mac=dict(required=False, type="str"), + mtu=dict(required=False, type="int"), + mss=dict(required=False, type="int"), + speed_duplex=dict(default="autoselect", required=False, type="str"), + ipv4_address=dict(required=False, type="str"), + ipv4_prefixlen=dict(default=24, required=False, type="int"), + ipv4_gateway=dict(required=False, type="str"), + ipv6_address=dict(required=False, type="str"), + ipv6_prefixlen=dict(default=128, required=False, type="int"), + ipv6_gateway=dict(required=False, type="str"), + blockpriv=dict(required=False, type="bool"), + blockbogons=dict(required=False, type="bool"), + slaacusev4iface=dict(required=False, type="bool"), ) INTERFACE_REQUIRED_IF = [ @@ -44,15 +52,15 @@ ["ipv6_type", "static", ["ipv6_address", "ipv6_prefixlen"]], ] -INTERFACE_MUTUALLY_EXCLUSIVE = [['interface', 'interface_descr']] +INTERFACE_MUTUALLY_EXCLUSIVE = [["interface", "interface_descr"]] class PFSenseInterfaceModule(PFSenseModuleBase): - """ module managing pfsense interfaces """ + """module managing pfsense interfaces""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return INTERFACE_ARGUMENT_SPEC ############################## @@ -73,12 +81,14 @@ def __init__(self, module, pfsense=None): # params processing # def _check_overlaps(self, ipfield, netfield): - """ check new address does not overlaps with one existing """ + """check new address does not overlaps with one existing""" if not self.obj.get(ipfield) or self.obj.get(netfield) is None: return - our_addr = ip_network(u'{0}/{1}'.format(self.obj[ipfield], self.obj[netfield]), strict=False) + our_addr = ip_network( + "{0}/{1}".format(self.obj[ipfield], self.obj[netfield]), strict=False + ) for iface in self.root_elt: if iface == self.target_elt: @@ -86,169 +96,242 @@ def _check_overlaps(self, ipfield, netfield): ipaddr_elt = iface.find(ipfield) subnet_elt = iface.find(netfield) - if ipaddr_elt is None or subnet_elt is None or ipaddr_elt.text in ['dhcp', None] or ipaddr_elt.text in ['dhcpv6', None]: + if ( + ipaddr_elt is None + or subnet_elt is None + or ipaddr_elt.text in ["dhcp", None] + or ipaddr_elt.text in ["dhcpv6", None] + ): continue - other_addr = ip_network(u'{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text), strict=False) + other_addr = ip_network( + "{0}/{1}".format(ipaddr_elt.text, subnet_elt.text), strict=False + ) if our_addr.overlaps(other_addr): - descr_elt = iface.find('descr') + descr_elt = iface.find("descr") if descr_elt is not None and descr_elt.text: ifname = descr_elt.text else: ifname = iface.tag - msg = 'IP address {0}/{1} is being used by or overlaps with: {2} ({3}/{4})'.format( + msg = "IP address {0}/{1} is being used by or overlaps with: {2} ({3}/{4})".format( self.obj[ipfield], self.obj[netfield], ifname, ipaddr_elt.text, - subnet_elt.text + subnet_elt.text, ) self.module.fail_json(msg=msg) def _params_to_obj(self): - """ return an interface dict from module params """ + """return an interface dict from module params""" params = self.params obj = dict() self.obj = obj - obj['descr'] = params['descr'] - if params['state'] == 'present': - obj['if'] = params['interface'] - - for param in ['enable', 'blockpriv', 'blockbogons']: - self._get_ansible_param_bool(obj, param, value='') - - self._get_ansible_param(obj, 'mac', fname='spoofmac', force=True) - self._get_ansible_param(obj, 'mtu') - self._get_ansible_param(obj, 'mss') - self._get_ansible_param(obj, 'speed_duplex', fname='media', exclude='autoselect') - - if params['ipv4_type'] == 'static': - self._get_ansible_param(obj, 'ipv4_address', fname='ipaddr') - self._get_ansible_param(obj, 'ipv4_prefixlen', fname='subnet') - self._get_ansible_param(obj, 'ipv4_gateway', fname='gateway') - elif params['ipv4_type'] == 'dhcp': - obj['ipaddr'] = 'dhcp' - - if params['ipv6_type'] == 'static': - self._get_ansible_param(obj, 'ipv6_address', fname='ipaddrv6') - self._get_ansible_param(obj, 'ipv6_prefixlen', fname='subnetv6') - self._get_ansible_param(obj, 'ipv6_gateway', fname='gatewayv6') - - if params['ipv6_type'] == 'slaac': - obj['ipaddrv6'] = 'slaac' - self._get_ansible_param_bool(obj, 'slaacusev4iface', value='') + obj["descr"] = params["descr"] + if params["state"] == "present": + obj["if"] = params["interface"] + + for param in ["enable", "blockpriv", "blockbogons"]: + self._get_ansible_param_bool(obj, param, value="") + + self._get_ansible_param(obj, "mac", fname="spoofmac", force=True) + self._get_ansible_param(obj, "mtu") + self._get_ansible_param(obj, "mss") + self._get_ansible_param( + obj, "speed_duplex", fname="media", exclude="autoselect" + ) + + if params["ipv4_type"] == "static": + self._get_ansible_param(obj, "ipv4_address", fname="ipaddr") + self._get_ansible_param(obj, "ipv4_prefixlen", fname="subnet") + self._get_ansible_param(obj, "ipv4_gateway", fname="gateway") + elif params["ipv4_type"] == "dhcp": + obj["ipaddr"] = "dhcp" + + if params["ipv6_type"] == "static": + self._get_ansible_param(obj, "ipv6_address", fname="ipaddrv6") + self._get_ansible_param(obj, "ipv6_prefixlen", fname="subnetv6") + self._get_ansible_param(obj, "ipv6_gateway", fname="gatewayv6") + + if params["ipv6_type"] == "slaac": + obj["ipaddrv6"] = "slaac" + self._get_ansible_param_bool(obj, "slaacusev4iface", value="") # get target interface self.target_elt = self._find_matching_interface() - self._check_overlaps('ipaddrv6', 'subnetv6') - self._check_overlaps('ipaddr', 'subnet') + self._check_overlaps("ipaddrv6", "subnetv6") + self._check_overlaps("ipaddr", "subnet") # check gateways - if self.obj.get('gateway') and not self.pfsense.find_gateway_elt(self.obj['gateway'], self.target_elt.tag, 'inet'): - self.module.fail_json(msg='Gateway {0} does not exist on {1}'.format(self.obj['gateway'], self.obj['descr'])) + if self.obj.get("gateway") and not self.pfsense.find_gateway_elt( + self.obj["gateway"], self.target_elt.tag, "inet" + ): + self.module.fail_json( + msg="Gateway {0} does not exist on {1}".format( + self.obj["gateway"], self.obj["descr"] + ) + ) - if self.obj.get('gatewayv6') and not self.pfsense.find_gateway_elt(self.obj['gatewayv6'], self.target_elt.tag, 'inet6'): - self.module.fail_json(msg='Gateway {0} does not exist on {1}'.format(self.obj['gatewayv6'], self.obj['descr'])) + if self.obj.get("gatewayv6") and not self.pfsense.find_gateway_elt( + self.obj["gatewayv6"], self.target_elt.tag, "inet6" + ): + self.module.fail_json( + msg="Gateway {0} does not exist on {1}".format( + self.obj["gatewayv6"], self.obj["descr"] + ) + ) else: - self.target_elt = self._get_interface_elt_by_display_name(self.obj['descr']) + self.target_elt = self._get_interface_elt_by_display_name(self.obj["descr"]) return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - if re.match('^[a-zA-Z0-9_]+$', params['descr']) is None: - self.module.fail_json(msg='The name of the interface may only consist of the characters "a-z, A-Z, 0-9 and _"') - - if params['state'] == 'present': - if params.get('mac') and re.match('^([0-9A-F]{2}[:-]){5}([0-9A-F]{2})$', params['mac']) is None: - self.module.fail_json(msg='MAC address must be in the following format: xx:xx:xx:xx:xx:xx (or blank).') + if re.match("^[a-zA-Z0-9_]+$", params["descr"]) is None: + self.module.fail_json( + msg='The name of the interface may only consist of the characters "a-z, A-Z, 0-9 and _"' + ) + + if params["state"] == "present": + if ( + params.get("mac") + and re.match("^([0-9A-F]{2}[:-]){5}([0-9A-F]{2})$", params["mac"]) + is None + ): + self.module.fail_json( + msg="MAC address must be in the following format: xx:xx:xx:xx:xx:xx (or blank)." + ) # todo can't change mac address on vlan interface - if params.get('ipv4_prefixlen') is not None and params['ipv4_prefixlen'] < 1 or params['ipv4_prefixlen'] > 32: - self.module.fail_json(msg='ipv4_prefixlen must be between 1 and 32.') + if ( + params.get("ipv4_prefixlen") is not None + and params["ipv4_prefixlen"] < 1 + or params["ipv4_prefixlen"] > 32 + ): + self.module.fail_json(msg="ipv4_prefixlen must be between 1 and 32.") - if params.get('ipv6_prefixlen') is not None and params['ipv6_prefixlen'] < 1 or params['ipv6_prefixlen'] > 128: - self.module.fail_json(msg='ipv6_prefixlen must be between 1 and 128.') + if ( + params.get("ipv6_prefixlen") is not None + and params["ipv6_prefixlen"] < 1 + or params["ipv6_prefixlen"] > 128 + ): + self.module.fail_json(msg="ipv6_prefixlen must be between 1 and 128.") - if params.get('mtu') is not None and params['mtu'] < 1: - self.module.fail_json(msg='mtu must be above 0') + if params.get("mtu") is not None and params["mtu"] < 1: + self.module.fail_json(msg="mtu must be above 0") - if params.get('mss') is not None and params['mtu'] < 1: - self.module.fail_json(msg='mtu must be above 0') + if params.get("mss") is not None and params["mtu"] < 1: + self.module.fail_json(msg="mtu must be above 0") interfaces = self._get_interface_list() - if params.get('interface') is not None: - if params['interface'] not in interfaces.keys(): + if params.get("interface") is not None: + if params["interface"] not in interfaces.keys(): self.module.fail_json( - msg='{0} can\'t be assigned. Interface may only be one the following: {1}'.format(params['interface'], list(interfaces.keys()))) - elif params.get('interface_descr') is not None: + msg="{0} can't be assigned. Interface may only be one the following: {1}".format( + params["interface"], list(interfaces.keys()) + ) + ) + elif params.get("interface_descr") is not None: for interface, attributes in interfaces.items(): - if 'descr' in attributes and attributes['descr'] == params['interface_descr']: - if params.get('interface') is not None: - self.module.fail_json(msg='Multiple interfaces found for "{0}"'.format(params['interface_descr'])) + if ( + "descr" in attributes + and attributes["descr"] == params["interface_descr"] + ): + if params.get("interface") is not None: + self.module.fail_json( + msg='Multiple interfaces found for "{0}"'.format( + params["interface_descr"] + ) + ) else: - params['interface'] = interface + params["interface"] = interface else: - self.module.fail_json(msg='one of the following is required: interface, interface_descr') - - media_modes = set(self._get_media_mode(params['interface'])) - media_modes.add('autoselect') - if params.get('speed_duplex') and params['speed_duplex'] not in media_modes: - self.module.fail_json(msg='For this interface, media mode may only be one the following: {0}'.format(media_modes)) + self.module.fail_json( + msg="one of the following is required: interface, interface_descr" + ) - if params['ipv4_type'] == 'static': - if params.get('ipv4_address') and not self.pfsense.is_ipv4_address(params['ipv4_address']): - self.module.fail_json(msg='{0} is not a valid IPv4 address'.format(params['ipv4_address'])) + media_modes = set(self._get_media_mode(params["interface"])) + media_modes.add("autoselect") + if params.get("speed_duplex") and params["speed_duplex"] not in media_modes: + self.module.fail_json( + msg="For this interface, media mode may only be one the following: {0}".format( + media_modes + ) + ) - if params['ipv6_type'] == 'static': - if params.get('ipv6_address') and not self.pfsense.is_ipv6_address(params['ipv6_address']): - self.module.fail_json(msg='{0} is not a valid IPv6 address'.format(params['ipv6_address'])) + if params["ipv4_type"] == "static": + if params.get("ipv4_address") and not self.pfsense.is_ipv4_address( + params["ipv4_address"] + ): + self.module.fail_json( + msg="{0} is not a valid IPv4 address".format( + params["ipv4_address"] + ) + ) + + if params["ipv6_type"] == "static": + if params.get("ipv6_address") and not self.pfsense.is_ipv6_address( + params["ipv6_address"] + ): + self.module.fail_json( + msg="{0} is not a valid IPv6 address".format( + params["ipv6_address"] + ) + ) ############################## # XML processing # def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.setup_interface_cmds += "interface_configure('{0}', true);\n".format(self.target_elt.tag) - self.result['ifname'] = self.target_elt.tag + self.setup_interface_cmds += "interface_configure('{0}', true);\n".format( + self.target_elt.tag + ) + self.result["ifname"] = self.target_elt.tag def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) if self._remove_deleted_params(): changed = True if changed: - if self.params['enable']: - self.setup_interface_cmds += "interface_bring_down('{0}', false);\n".format(self.target_elt.tag) - self.setup_interface_cmds += "interface_configure('{0}', true);\n".format(self.target_elt.tag) + if self.params["enable"]: + self.setup_interface_cmds += ( + "interface_bring_down('{0}', false);\n".format(self.target_elt.tag) + ) + self.setup_interface_cmds += ( + "interface_configure('{0}', true);\n".format(self.target_elt.tag) + ) else: - self.setup_interface_cmds += "interface_bring_down('{0}', true);\n".format(self.target_elt.tag) + self.setup_interface_cmds += ( + "interface_bring_down('{0}', true);\n".format(self.target_elt.tag) + ) - self.result['ifname'] = self.target_elt.tag + self.result["ifname"] = self.target_elt.tag return (before, changed) def _create_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" # wan can't be deleted, so the first interface we can create is lan - if self.pfsense.get_interface_elt('lan') is None: - interface_elt = self.pfsense.new_element('lan') + if self.pfsense.get_interface_elt("lan") is None: + interface_elt = self.pfsense.new_element("lan") self.root_elt.insert(1, interface_elt) return interface_elt # lan is used, so we must create an optX interface i = 1 while True: - interface = 'opt{0}'.format(i) + interface = "opt{0}".format(i) if self.pfsense.get_interface_elt(interface) is None: interface_elt = self.pfsense.new_element(interface) # i + 1 = i + (lan and wan) - 1 @@ -257,19 +340,22 @@ def _create_target(self): i = i + 1 def _get_interface_elt_by_port_and_display_name(self, interface_port, name): - """ return pfsense interface_elt """ + """return pfsense interface_elt""" for iface in self.root_elt: - descr_elt = iface.find('descr') + descr_elt = iface.find("descr") if descr_elt is None: continue - if iface.find('if').text.strip() == interface_port and descr_elt.text.strip().lower() == name.lower(): + if ( + iface.find("if").text.strip() == interface_port + and descr_elt.text.strip().lower() == name.lower() + ): return iface return None def _get_interface_elt_by_display_name(self, name): - """ return pfsense interface by name """ + """return pfsense interface by name""" for iface in self.root_elt: - descr_elt = iface.find('descr') + descr_elt = iface.find("descr") if descr_elt is None: continue if descr_elt.text.strip().lower() == name.lower(): @@ -277,10 +363,10 @@ def _get_interface_elt_by_display_name(self, name): return None def _get_interface_display_name_by_port(self, interface_port): - """ return pfsense interface physical name """ + """return pfsense interface physical name""" for iface in self.root_elt: - if iface.find('if').text.strip() == interface_port: - descr_elt = iface.find('descr') + if iface.find("if").text.strip() == interface_port: + descr_elt = iface.find("descr") if descr_elt is not None: return descr_elt.text.strip() return iface.tag @@ -288,85 +374,120 @@ def _get_interface_display_name_by_port(self, interface_port): return None def _get_interface_elt_by_port(self, interface_port): - """ find pfsense interface by port name """ + """find pfsense interface by port name""" for iface in self.root_elt: - if iface.find('if').text.strip() == interface_port: + if iface.find("if").text.strip() == interface_port: return iface return None def _find_matching_interface(self): - """ return target interface """ + """return target interface""" # we first try to find an interface having same port and display name - interface_elt = self._get_interface_elt_by_port_and_display_name(self.obj['if'], self.obj['descr']) + interface_elt = self._get_interface_elt_by_port_and_display_name( + self.obj["if"], self.obj["descr"] + ) if interface_elt is not None: return interface_elt # we then try to find an existing interface with the same display name - interface_elt = self._get_interface_elt_by_display_name(self.obj['descr']) + interface_elt = self._get_interface_elt_by_display_name(self.obj["descr"]) if interface_elt is not None: # we check the target port can be used - used_by = self._get_interface_display_name_by_port(self.obj['if']) + used_by = self._get_interface_display_name_by_port(self.obj["if"]) if used_by is not None: - self.module.fail_json(msg='Port {0} is already in use on interface {1}'.format(self.obj['if'], used_by)) + self.module.fail_json( + msg="Port {0} is already in use on interface {1}".format( + self.obj["if"], used_by + ) + ) return interface_elt # last, we try to find an existing interface with the port (interface will be renamed) - return self._get_interface_elt_by_port(self.obj['if']) + return self._get_interface_elt_by_port(self.obj["if"]) def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" return self.target_elt @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - params = ['mtu', 'mss', 'gateway', 'enable', 'mac', 'media', 'ipaddr', 'subnet', 'ipaddrv6', 'subnetv6', 'gatewayv6', 'blockpriv', 'blockbogons'] + """returns the list of params to remove if they are not set""" + params = [ + "mtu", + "mss", + "gateway", + "enable", + "mac", + "media", + "ipaddr", + "subnet", + "ipaddrv6", + "subnetv6", + "gatewayv6", + "blockpriv", + "blockbogons", + ] return params def _pre_remove_target_elt(self): - """ processing before removing elt """ - self.obj['if'] = self.target_elt.find('if').text + """processing before removing elt""" + self.obj["if"] = self.target_elt.find("if").text ifname = self.target_elt.tag if self.pfsense.ifgroups is not None: for ifgroup_elt in self.pfsense.ifgroups.findall("ifgroupentry"): - if ifgroup_elt.find('members') is not None: - members = ifgroup_elt.find('members').text.split() + if ifgroup_elt.find("members") is not None: + members = ifgroup_elt.find("members").text.split() if ifname in members: - self.module.fail_json(msg='The interface is part of the group {0}. Please remove it from the group first.'.format( - ifgroup_elt.find('ifname').text)) + self.module.fail_json( + msg="The interface is part of the group {0}. Please remove it from the group first.".format( + ifgroup_elt.find("ifname").text + ) + ) self._remove_all_separators(ifname) self._remove_all_rules(ifname) self.setup_interface_pre_cmds += "interface_bring_down('{0}');\n".format(ifname) - self.result['ifname'] = ifname + self.result["ifname"] = ifname def _remove_all_rules(self, interface): - """ delete all interface rules """ + """delete all interface rules""" # we use the pfsense_rule module to delete the rules since, at least for floating rules, # it implies to recalculate separators positions # if we have to just remove the deleted interface of a floating rule we do it ourselves todel = [] for rule_elt in self.pfsense.rules: - if rule_elt.find('floating') is not None: - interfaces = rule_elt.find('interface').text.split(',') - old_ifs = ','.join([self.pfsense.get_interface_display_name(old_interface) for old_interface in interfaces]) + if rule_elt.find("floating") is not None: + interfaces = rule_elt.find("interface").text.split(",") + old_ifs = ",".join( + [ + self.pfsense.get_interface_display_name(old_interface) + for old_interface in interfaces + ] + ) if interface in interfaces: if len(interfaces) > 1: interfaces.remove(interface) - new_ifs = ','.join([self.pfsense.get_interface_display_name(new_interface) for new_interface in interfaces]) - rule_elt.find('interface').text = ','.join(interfaces) - cmd = 'update rule \'{0}\' on \'floating({1})\' set interface=\'{2}\''.format(rule_elt.find('descr').text, old_ifs, new_ifs) - self.result['commands'].append(cmd) + new_ifs = ",".join( + [ + self.pfsense.get_interface_display_name(new_interface) + for new_interface in interfaces + ] + ) + rule_elt.find("interface").text = ",".join(interfaces) + cmd = "update rule '{0}' on 'floating({1})' set interface='{2}'".format( + rule_elt.find("descr").text, old_ifs, new_ifs + ) + self.result["commands"].append(cmd) continue todel.append(rule_elt) else: continue else: - iface = rule_elt.find('interface') + iface = rule_elt.find("interface") if iface is not None and iface.text == interface: todel.append(rule_elt) @@ -374,27 +495,29 @@ def _remove_all_rules(self, interface): pfsense_rules = PFSenseRuleModule(self.module, self.pfsense) for rule_elt in todel: params = {} - params['state'] = 'absent' - params['name'] = rule_elt.find('descr').text - params['interface'] = rule_elt.find('interface').text - if rule_elt.find('floating') is not None: - params['floating'] = True + params["state"] = "absent" + params["name"] = rule_elt.find("descr").text + params["interface"] = rule_elt.find("interface").text + if rule_elt.find("floating") is not None: + params["floating"] = True pfsense_rules.run(params) - if pfsense_rules.result['commands']: - self.result['commands'].extend(pfsense_rules.result['commands']) + if pfsense_rules.result["commands"]: + self.result["commands"].extend(pfsense_rules.result["commands"]) def _remove_all_separators(self, interface): - """ delete all interface separators """ + """delete all interface separators""" todel = [] - separators = self.pfsense.rules.find('separator') or [] + separators = self.pfsense.rules.find("separator") or [] for interface_elt in separators: if interface_elt.tag != interface: continue for separator_elt in interface_elt: todel.append(separator_elt) for separator_elt in todel: - cmd = 'delete rule_separator \'{0}\', interface=\'{1}\''.format(separator_elt.find('text').text, interface) - self.result['commands'].append(cmd) + cmd = "delete rule_separator '{0}', interface='{1}'".format( + separator_elt.find("text").text, interface + ) + self.result["commands"].append(cmd) interface_elt.remove(separator_elt) separators.remove(interface_elt) break @@ -460,27 +583,31 @@ def _get_interface_list(self): "$ipsec_descrs = interface_ipsec_vti_list_all();" "foreach ($ipsec_descrs as $ifname => $ifdescr) $portlist[$ifname] = array('descr' => $ifdescr);" "" - "echo json_encode($portlist, JSON_PRETTY_PRINT);") + "echo json_encode($portlist, JSON_PRETTY_PRINT);" + ) def _get_media_mode(self, interface): - """ Find all possible media options for the interface """ + """Find all possible media options for the interface""" return self.pfsense.php( - '$mediaopts_list = array();\n' - 'exec("/sbin/ifconfig -m ' + interface + ' | grep \'media \'", $mediaopts);\n' - 'foreach ($mediaopts as $mediaopt) {\n' + "$mediaopts_list = array();\n" + 'exec("/sbin/ifconfig -m ' + + interface + + " | grep 'media '\", $mediaopts);\n" + "foreach ($mediaopts as $mediaopt) {\n" ' preg_match("/media (.*)/", $mediaopt, $matches);\n' ' if (preg_match("/(.*) mediaopt (.*)/", $matches[1], $matches1)) {\n' ' // there is media + mediaopt like "media 1000baseT mediaopt full-duplex"\n' ' array_push($mediaopts_list, $matches1[1] . " " . $matches1[2]);\n' - ' } else {\n' + " } else {\n" ' // there is only media like "media 1000baseT"\n' - ' array_push($mediaopts_list, $matches[1]);\n' - ' }\n' - '}\n' - 'echo json_encode($mediaopts_list);') + " array_push($mediaopts_list, $matches[1]);\n" + " }\n" + "}\n" + "echo json_encode($mediaopts_list);" + ) def get_pre_update_cmds(self): - """ build and return php commands to setup interfaces before changing config """ + """build and return php commands to setup interfaces before changing config""" cmd = 'require_once("filter.inc");\n' cmd += 'require_once("interfaces.inc");\n' @@ -490,7 +617,7 @@ def get_pre_update_cmds(self): return cmd def get_update_cmds(self): - """ build and return php commands to setup interfaces """ + """build and return php commands to setup interfaces""" cmd = 'require_once("filter.inc");\n' cmd += 'require_once("interfaces.inc");\n' cmd += 'require_once("services.inc");\n' @@ -501,8 +628,8 @@ def get_update_cmds(self): if self.setup_interface_cmds != "": cmd += self.setup_interface_cmds - cmd += 'services_snmpd_configure();\n' - cmd += 'setup_gateways_monitor();\n' + cmd += "services_snmpd_configure();\n" + cmd += "setup_gateways_monitor();\n" cmd += "clear_subsystem_dirty('interfaces');\n" cmd += "filter_configure();\n" cmd += "enable_rrd_graphing();\n" @@ -510,76 +637,135 @@ def get_update_cmds(self): return cmd def _pre_update(self): - """ tasks to run before making config changes """ + """tasks to run before making config changes""" return self.pfsense.phpshell(self.get_pre_update_cmds()) def _update(self): - """ make the target pfsense reload interfaces """ + """make the target pfsense reload interfaces""" return self.pfsense.phpshell(self.get_update_cmds()) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.obj['descr']) + """return obj's name""" + return "'{0}'".format(self.obj["descr"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'if', fname='port') - values += self.format_cli_field(self.obj, 'enable', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'ipv4_type', default='none') - values += self.format_cli_field(self.obj, 'ipaddr', fname='ipv4_address') - values += self.format_cli_field(self.obj, 'subnet', fname='ipv4_prefixlen') - values += self.format_cli_field(self.obj, 'gateway', fname='ipv4_gateway') - values += self.format_cli_field(self.params, 'ipv6_type', default='none') - if self.obj.get('ipaddrv6') != 'slaac': - values += self.format_cli_field(self.obj, 'ipaddrv6', fname='ipv6_address') - values += self.format_cli_field(self.obj, 'subnetv6', fname='ipv6_prefixlen') - values += self.format_cli_field(self.obj, 'gatewayv6', fname='ipv6_gateway') - values += self.format_cli_field(self.params, 'mac') - values += self.format_cli_field(self.obj, 'mtu') - values += self.format_cli_field(self.obj, 'mss') - values += self.format_cli_field(self.obj, 'blockpriv', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.obj, 'blockbogons', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'speed_duplex', fname='speed_duplex', default='autoselect') + values += self.format_cli_field(self.obj, "if", fname="port") + values += self.format_cli_field(self.obj, "enable", fvalue=self.fvalue_bool) + values += self.format_cli_field(self.params, "ipv4_type", default="none") + values += self.format_cli_field(self.obj, "ipaddr", fname="ipv4_address") + values += self.format_cli_field(self.obj, "subnet", fname="ipv4_prefixlen") + values += self.format_cli_field(self.obj, "gateway", fname="ipv4_gateway") + values += self.format_cli_field(self.params, "ipv6_type", default="none") + if self.obj.get("ipaddrv6") != "slaac": + values += self.format_cli_field( + self.obj, "ipaddrv6", fname="ipv6_address" + ) + values += self.format_cli_field( + self.obj, "subnetv6", fname="ipv6_prefixlen" + ) + values += self.format_cli_field(self.obj, "gatewayv6", fname="ipv6_gateway") + values += self.format_cli_field(self.params, "mac") + values += self.format_cli_field(self.obj, "mtu") + values += self.format_cli_field(self.obj, "mss") + values += self.format_cli_field( + self.obj, "blockpriv", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.obj, "blockbogons", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.params, "speed_duplex", fname="speed_duplex", default="autoselect" + ) else: # todo: - detect before ipv4_type for proper logging - values += self.format_updated_cli_field(self.obj, before, 'descr', add_comma=(values), fname='interface') - values += self.format_updated_cli_field(self.obj, before, 'if', add_comma=(values), fname='port') - values += self.format_updated_cli_field(self.obj, before, 'enable', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'ipv4_type', add_comma=(values), log_none='True') - values += self.format_updated_cli_field(self.obj, before, 'ipaddr', add_comma=(values), fname='ipv4_address') - values += self.format_updated_cli_field(self.obj, before, 'subnet', add_comma=(values), fname='ipv4_prefixlen') - values += self.format_updated_cli_field(self.obj, before, 'gateway', add_comma=(values), fname='ipv4_gateway') - if self.obj.get('ipaddrv6') == 'slaac' and before.get('ipaddrv6') != 'slaac': + values += self.format_updated_cli_field( + self.obj, before, "descr", add_comma=(values), fname="interface" + ) + values += self.format_updated_cli_field( + self.obj, before, "if", add_comma=(values), fname="port" + ) + values += self.format_updated_cli_field( + self.obj, before, "enable", add_comma=(values), fvalue=self.fvalue_bool + ) + values += self.format_updated_cli_field( + self.obj, before, "ipv4_type", add_comma=(values), log_none="True" + ) + values += self.format_updated_cli_field( + self.obj, before, "ipaddr", add_comma=(values), fname="ipv4_address" + ) + values += self.format_updated_cli_field( + self.obj, before, "subnet", add_comma=(values), fname="ipv4_prefixlen" + ) + values += self.format_updated_cli_field( + self.obj, before, "gateway", add_comma=(values), fname="ipv4_gateway" + ) + if ( + self.obj.get("ipaddrv6") == "slaac" + and before.get("ipaddrv6") != "slaac" + ): res = "ipv6_type=slaac" if values: values += ", " + res else: values += res else: - values += self.format_updated_cli_field(self.obj, before, 'ipv6_type', add_comma=(values), log_none='True') - values += self.format_updated_cli_field(self.obj, before, 'ipaddrv6', add_comma=(values), fname='ipv6_address') - values += self.format_updated_cli_field(self.obj, before, 'subnetv6', add_comma=(values), fname='ipv6_prefixlen') - values += self.format_updated_cli_field(self.obj, before, 'gatewayv6', add_comma=(values), fname='ipv6_gateway') - values += self.format_updated_cli_field(self.obj, before, 'spoofmac', add_comma=(values), fname='mac') - values += self.format_updated_cli_field(self.obj, before, 'mtu', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'mss', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'media', add_comma=(values), fname='speed_duplex') - values += self.format_updated_cli_field(self.obj, before, 'blockpriv', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'blockbogons', add_comma=(values), fvalue=self.fvalue_bool) + values += self.format_updated_cli_field( + self.obj, before, "ipv6_type", add_comma=(values), log_none="True" + ) + values += self.format_updated_cli_field( + self.obj, + before, + "ipaddrv6", + add_comma=(values), + fname="ipv6_address", + ) + values += self.format_updated_cli_field( + self.obj, before, "subnetv6", add_comma=(values), fname="ipv6_prefixlen" + ) + values += self.format_updated_cli_field( + self.obj, before, "gatewayv6", add_comma=(values), fname="ipv6_gateway" + ) + values += self.format_updated_cli_field( + self.obj, before, "spoofmac", add_comma=(values), fname="mac" + ) + values += self.format_updated_cli_field( + self.obj, before, "mtu", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "mss", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "media", add_comma=(values), fname="speed_duplex" + ) + values += self.format_updated_cli_field( + self.obj, + before, + "blockpriv", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "blockbogons", + add_comma=(values), + fvalue=self.fvalue_bool, + ) return values def _log_update(self, before): - """ generate pseudo-CLI command to update an interface """ + """generate pseudo-CLI command to update an interface""" log = "update {0} '{1}'".format( self._get_module_name(True), # pfSense doesn't enforce a descr on an interface, especially on # first-run so fallback to interface specifier if not known - before.get('descr', before['if']), + before.get("descr", before["if"]), ) values = self._log_fields(before) - self.result['commands'].append(log + ' set ' + values) + self.result["commands"].append(log + " set " + values) diff --git a/plugins/module_utils/interface_group.py b/plugins/module_utils/interface_group.py index 5744047b..b0274c85 100644 --- a/plugins/module_utils/interface_group.py +++ b/plugins/module_utils/interface_group.py @@ -4,113 +4,143 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase -from ansible_collections.pfsensible.core.plugins.module_utils.rule import PFSenseRuleModule +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule import ( + PFSenseRuleModule, +) INTERFACE_GROUP_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - name=dict(required=True, type='str'), - descr=dict(type='str'), - members=dict(type='list', elements='str'), + state=dict(default="present", choices=["present", "absent"]), + name=dict(required=True, type="str"), + descr=dict(type="str"), + members=dict(type="list", elements="str"), ) INTERFACE_GROUP_REQUIRED_IF = [ - ['state', 'present', ['members']], + ["state", "present", ["members"]], ] -INTERFACE_GROUP_PHP_COMMAND = ''' +INTERFACE_GROUP_PHP_COMMAND = """ require_once("interfaces.inc"); {0} -interface_group_setup($ifgroupentry);''' +interface_group_setup($ifgroupentry);""" class PFSenseInterfaceGroupModule(PFSenseModuleBase): - """ module managing pfsense interfaces """ + """module managing pfsense interfaces""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return INTERFACE_GROUP_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseInterfaceGroupModule, self).__init__(module, pfsense, root='ifgroups', create_root=True, node='ifgroupentry', key='ifname') + super(PFSenseInterfaceGroupModule, self).__init__( + module, + pfsense, + root="ifgroups", + create_root=True, + node="ifgroupentry", + key="ifname", + ) self.name = "pfsense_interface_group" ############################## # params processing # def _params_to_obj(self): - """ return an interface dict from module params """ + """return an interface dict from module params""" params = self.params obj = dict() self.obj = obj - obj['ifname'] = params['name'] - if params['state'] == 'present': - obj['descr'] = params['descr'] + obj["ifname"] = params["name"] + if params["state"] == "present": + obj["descr"] = params["descr"] members = [] - for interface in params['members']: + for interface in params["members"]: if self.pfsense.is_interface_display_name(interface): - members.append(self.pfsense.get_interface_by_display_name(interface)) + members.append( + self.pfsense.get_interface_by_display_name(interface) + ) elif self.pfsense.is_interface_port(interface): members.append(interface) else: - self.module.fail_json(msg='Unknown interface name "{0}".'.format(interface)) - obj['members'] = ' '.join(members) - self.result['member_ifnames'] = members + self.module.fail_json( + msg='Unknown interface name "{0}".'.format(interface) + ) + obj["members"] = " ".join(members) + self.result["member_ifnames"] = members return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - if re.match('^[a-zA-Z0-9_]+$', params['name']) is None: - self.module.fail_json(msg='The name of the interface group may only consist of the characters "a-z, A-Z, 0-9 and _"') - if len(params['name']) > 15: - self.module.fail_json(msg='Group name cannot have more than 15 characters.') - if re.match('[0-9]$', params['name']) is not None: - self.module.fail_json(msg='Group name cannot end with a digit.') + if re.match("^[a-zA-Z0-9_]+$", params["name"]) is None: + self.module.fail_json( + msg='The name of the interface group may only consist of the characters "a-z, A-Z, 0-9 and _"' + ) + if len(params["name"]) > 15: + self.module.fail_json(msg="Group name cannot have more than 15 characters.") + if re.match("[0-9]$", params["name"]) is not None: + self.module.fail_json(msg="Group name cannot end with a digit.") # Make sure list of interfaces is a unique set - if params['state'] == 'present': - if len(params['members']) > len(set(params['members'])): - self.module.fail_json(msg='List of members is not unique.') + if params["state"] == "present": + if len(params["members"]) > len(set(params["members"])): + self.module.fail_json(msg="List of members is not unique.") # TODO - check that name isn't in use by any interfaces ############################## # XML processing # def _remove_all_rules(self, interface): - """ delete all interface rules """ + """delete all interface rules""" # we use the pfsense_rule module to delete the rules since, at least for floating rules, # it implies to recalculate separators positions # if we have to just remove the deleted interface of a floating rule we do it ourselves todel = [] for rule_elt in self.pfsense.rules: - if rule_elt.find('floating') is not None: - interfaces = rule_elt.find('interface').text.split(',') - old_ifs = ','.join([self.pfsense.get_interface_display_name(old_interface) for old_interface in interfaces]) + if rule_elt.find("floating") is not None: + interfaces = rule_elt.find("interface").text.split(",") + old_ifs = ",".join( + [ + self.pfsense.get_interface_display_name(old_interface) + for old_interface in interfaces + ] + ) if interface in interfaces: if len(interfaces) > 1: interfaces.remove(interface) - new_ifs = ','.join([self.pfsense.get_interface_display_name(new_interface) for new_interface in interfaces]) - rule_elt.find('interface').text = ','.join(interfaces) - cmd = 'update rule \'{0}\' on \'floating({1})\' set interface=\'{2}\''.format(rule_elt.find('descr').text, old_ifs, new_ifs) - self.result['commands'].append(cmd) + new_ifs = ",".join( + [ + self.pfsense.get_interface_display_name(new_interface) + for new_interface in interfaces + ] + ) + rule_elt.find("interface").text = ",".join(interfaces) + cmd = "update rule '{0}' on 'floating({1})' set interface='{2}'".format( + rule_elt.find("descr").text, old_ifs, new_ifs + ) + self.result["commands"].append(cmd) continue todel.append(rule_elt) else: continue else: - iface = rule_elt.find('interface') + iface = rule_elt.find("interface") if iface is not None and iface.text == interface: todel.append(rule_elt) @@ -118,27 +148,29 @@ def _remove_all_rules(self, interface): pfsense_rules = PFSenseRuleModule(self.module, self.pfsense) for rule_elt in todel: params = {} - params['state'] = 'absent' - params['name'] = rule_elt.find('descr').text - params['interface'] = rule_elt.find('interface').text - if rule_elt.find('floating') is not None: - params['floating'] = True + params["state"] = "absent" + params["name"] = rule_elt.find("descr").text + params["interface"] = rule_elt.find("interface").text + if rule_elt.find("floating") is not None: + params["floating"] = True pfsense_rules.run(params) - if pfsense_rules.result['commands']: - self.result['commands'].extend(pfsense_rules.result['commands']) + if pfsense_rules.result["commands"]: + self.result["commands"].extend(pfsense_rules.result["commands"]) def _remove_all_separators(self, interface): - """ delete all interface separators """ + """delete all interface separators""" todel = [] - separators = self.pfsense.rules.find('separator') + separators = self.pfsense.rules.find("separator") for interface_elt in separators: if interface_elt.tag != interface: continue for separator_elt in interface_elt: todel.append(separator_elt) for separator_elt in todel: - cmd = 'delete rule_separator \'{0}\', interface=\'{1}\''.format(separator_elt.find('text').text, interface) - self.result['commands'].append(cmd) + cmd = "delete rule_separator '{0}', interface='{1}'".format( + separator_elt.find("text").text, interface + ) + self.result["commands"].append(cmd) interface_elt.remove(separator_elt) separators.remove(interface_elt) break @@ -148,19 +180,27 @@ def _remove_all_separators(self, interface): # def _update(self): - """ make the target pfsense reload interfaces """ - return self.pfsense.phpshell(INTERFACE_GROUP_PHP_COMMAND.format(self.pfsense.dict_to_php(self.obj, 'ifgroupentry'))) + """make the target pfsense reload interfaces""" + return self.pfsense.phpshell( + INTERFACE_GROUP_PHP_COMMAND.format( + self.pfsense.dict_to_php(self.obj, "ifgroupentry") + ) + ) ############################## # Logging # def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'descr') - values += self.format_cli_field(self.obj, 'members') + values += self.format_cli_field(self.obj, "descr") + values += self.format_cli_field(self.obj, "members") else: - values += self.format_updated_cli_field(self.obj, before, 'descr', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'members', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "descr", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "members", add_comma=(values) + ) return values diff --git a/plugins/module_utils/ipsec.py b/plugins/module_utils/ipsec.py index 2966c981..6d1f5670 100644 --- a/plugins/module_utils/ipsec.py +++ b/plugins/module_utils/ipsec.py @@ -4,59 +4,83 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) IPSEC_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - descr=dict(required=True, type='str'), - iketype=dict(choices=['ikev1', 'ikev2', 'auto'], type='str'), - protocol=dict(default='inet', choices=['inet', 'inet6', 'both']), - interface=dict(required=False, type='str'), - remote_gateway=dict(required=False, type='str'), - nattport=dict(required=False, type='int'), - - disabled=dict(required=False, type='bool'), - - authentication_method=dict(choices=['pre_shared_key', 'rsasig']), - mode=dict(required=False, choices=['main', 'aggressive']), - myid_type=dict(default='myaddress', choices=['myaddress', 'address', 'fqdn', 'user_fqdn', 'asn1dn', 'keyid tag', 'dyn_dns', 'auto']), - myid_data=dict(required=False, type='str'), - peerid_type=dict(default='peeraddress', choices=['any', 'peeraddress', 'address', 'fqdn', 'user_fqdn', 'asn1dn', 'keyid tag', 'auto']), - peerid_data=dict(required=False, type='str'), - certificate=dict(required=False, type='str'), - certificate_authority=dict(required=False, type='str'), - preshared_key=dict(required=False, type='str', no_log=True), - - lifetime=dict(default=28800, type='int'), - rekey_time=dict(required=False, type='int'), - reauth_time=dict(required=False, type='int'), - rand_time=dict(required=False, type='int'), - - disable_rekey=dict(required=False, type='bool'), - margintime=dict(required=False, type='int'), - startaction=dict(default='', choices=['', 'none', 'start', 'trap']), - closeaction=dict(default='', choices=['', 'none', 'start', 'trap']), - disable_reauth=dict(default=False, type='bool'), - mobike=dict(default='off', choices=['on', 'off']), - gw_duplicates=dict(required=False, type='bool'), - splitconn=dict(default=False, type='bool'), - - nat_traversal=dict(default='on', choices=['on', 'force']), - enable_dpd=dict(default=True, type='bool'), - dpd_delay=dict(default=10, type='int'), - dpd_maxfail=dict(default=5, type='int'), - apply=dict(default=True, type='bool'), - + state=dict(default="present", choices=["present", "absent"]), + descr=dict(required=True, type="str"), + iketype=dict(choices=["ikev1", "ikev2", "auto"], type="str"), + protocol=dict(default="inet", choices=["inet", "inet6", "both"]), + interface=dict(required=False, type="str"), + remote_gateway=dict(required=False, type="str"), + nattport=dict(required=False, type="int"), + disabled=dict(required=False, type="bool"), + authentication_method=dict(choices=["pre_shared_key", "rsasig"]), + mode=dict(required=False, choices=["main", "aggressive"]), + myid_type=dict( + default="myaddress", + choices=[ + "myaddress", + "address", + "fqdn", + "user_fqdn", + "asn1dn", + "keyid tag", + "dyn_dns", + "auto", + ], + ), + myid_data=dict(required=False, type="str"), + peerid_type=dict( + default="peeraddress", + choices=[ + "any", + "peeraddress", + "address", + "fqdn", + "user_fqdn", + "asn1dn", + "keyid tag", + "auto", + ], + ), + peerid_data=dict(required=False, type="str"), + certificate=dict(required=False, type="str"), + certificate_authority=dict(required=False, type="str"), + preshared_key=dict(required=False, type="str", no_log=True), + lifetime=dict(default=28800, type="int"), + rekey_time=dict(required=False, type="int"), + reauth_time=dict(required=False, type="int"), + rand_time=dict(required=False, type="int"), + disable_rekey=dict(required=False, type="bool"), + margintime=dict(required=False, type="int"), + startaction=dict(default="", choices=["", "none", "start", "trap"]), + closeaction=dict(default="", choices=["", "none", "start", "trap"]), + disable_reauth=dict(default=False, type="bool"), + mobike=dict(default="off", choices=["on", "off"]), + gw_duplicates=dict(required=False, type="bool"), + splitconn=dict(default=False, type="bool"), + nat_traversal=dict(default="on", choices=["on", "force"]), + enable_dpd=dict(default=True, type="bool"), + dpd_delay=dict(default=10, type="int"), + dpd_maxfail=dict(default=5, type="int"), + apply=dict(default=True, type="bool"), # Dropped in 2.5.2 - responderonly=dict(required=False, type='bool'), + responderonly=dict(required=False, type="bool"), ) IPSEC_REQUIRED_IF = [ - ["state", "present", ["remote_gateway", "interface", "iketype", "authentication_method"]], - + [ + "state", + "present", + ["remote_gateway", "interface", "iketype", "authentication_method"], + ], ["enable_dpd", True, ["dpd_delay", "dpd_maxfail"]], ["iketype", "auto", ["mode"]], ["iketype", "ikev1", ["mode"]], @@ -68,7 +92,6 @@ ["myid_type", "asn1dn", ["myid_data"]], ["myid_type", "keyid tag", ["myid_data"]], ["myid_type", "dyn_dns", ["myid_data"]], - ["peerid_type", "address", ["peerid_data"]], ["peerid_type", "fqdn", ["peerid_data"]], ["peerid_type", "user_fqdn", ["peerid_data"]], @@ -78,12 +101,12 @@ # Booleans that map to different values IPSEC_BOOL_VALUES = dict( - gw_duplicates=(None, ''), + gw_duplicates=(None, ""), ) IPSEC_MAP_PARAM = [ - ('preshared_key', 'pre-shared-key'), - ('remote_gateway', 'remote-gateway'), + ("preshared_key", "pre-shared-key"), + ("remote_gateway", "remote-gateway"), ] IPSEC_CREATE_DEFAULT = dict( @@ -96,31 +119,39 @@ def p2o_ipsec_interface(self, name, params, obj): # Valid interfaces are physical, virtual IPs, and gateway groups # TODO - handle gateway groups - if params[name].lower().startswith('vip:'): + if params[name].lower().startswith("vip:"): obj[name] = self.pfsense.get_virtual_ip_interface(params[name][4:]) else: obj[name] = self.pfsense.parse_interface(params[name], with_virtual=False) IPSEC_ARG_ROUTE = dict( - interface=dict(parse=p2o_ipsec_interface,), + interface=dict( + parse=p2o_ipsec_interface, + ), ) class PFSenseIpsecModule(PFSenseModuleBase): - """ module managing pfsense ipsec tunnels phase 1 options """ + """module managing pfsense ipsec tunnels phase 1 options""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return IPSEC_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseIpsecModule, self).__init__(module, pfsense, arg_route=IPSEC_ARG_ROUTE, bool_values=IPSEC_BOOL_VALUES, map_param=IPSEC_MAP_PARAM, - create_default=IPSEC_CREATE_DEFAULT) + super(PFSenseIpsecModule, self).__init__( + module, + pfsense, + arg_route=IPSEC_ARG_ROUTE, + bool_values=IPSEC_BOOL_VALUES, + map_param=IPSEC_MAP_PARAM, + create_default=IPSEC_CREATE_DEFAULT, + ) # Override for use with aggregate self.argument_spec = IPSEC_ARGUMENT_SPEC self.name = "pfsense_ipsec" @@ -132,18 +163,18 @@ def __init__(self, module, pfsense=None): # XML processing # def _create_target(self): - """ create the XML target_elt """ - ipsec_elt = self.pfsense.new_element('phase1') - self.obj['ikeid'] = str(self._find_free_ikeid()) + """create the XML target_elt""" + ipsec_elt = self.pfsense.new_element("phase1") + self.obj["ikeid"] = str(self._find_free_ikeid()) return ipsec_elt def _find_free_ikeid(self): - """ return first unused ikeid """ + """return first unused ikeid""" ikeid = 1 while True: found = False for ipsec_elt in self.root_elt: - ikeid_elt = ipsec_elt.find('ikeid') + ikeid_elt = ipsec_elt.find("ikeid") if ikeid_elt is not None and ikeid_elt.text == str(ikeid): found = True break @@ -153,36 +184,43 @@ def _find_free_ikeid(self): ikeid = ikeid + 1 def _find_target(self): - """ find the XML target_elt """ - if self.params.get('ikeid') is not None: - return self.pfsense.find_ipsec_phase1(self.params['ikeid'], 'ikeid') - return self.pfsense.find_ipsec_phase1(self.obj['descr']) + """find the XML target_elt""" + if self.params.get("ikeid") is not None: + return self.pfsense.find_ipsec_phase1(self.params["ikeid"], "ikeid") + return self.pfsense.find_ipsec_phase1(self.obj["descr"]) def _get_params_to_remove(self): - """ returns the list of params to remove if they are not set """ - params = ['disabled', 'rekey_enable', 'reauth_enable', 'splitconn', 'nattport', 'gw_duplicates'] - if self.params.get('disable_rekey'): - params.append('margintime') + """returns the list of params to remove if they are not set""" + params = [ + "disabled", + "rekey_enable", + "reauth_enable", + "splitconn", + "nattport", + "gw_duplicates", + ] + if self.params.get("disable_rekey"): + params.append("margintime") - if not self.params['enable_dpd']: - params.append('dpd_delay') - params.append('dpd_maxfail') + if not self.params["enable_dpd"]: + params.append("dpd_delay") + params.append("dpd_maxfail") return params def _pre_remove_target_elt(self): - """ processing before removing elt """ + """processing before removing elt""" self._remove_phases2() def _remove_phases2(self): - """ remove phase2 elts from xml """ - ikeid_elt = self.target_elt.find('ikeid') + """remove phase2 elts from xml""" + ikeid_elt = self.target_elt.find("ikeid") if ikeid_elt is None: return ikeid = ikeid_elt.text - phase2_elts = self.root_elt.findall('phase2') + phase2_elts = self.root_elt.findall("phase2") for phase2_elt in phase2_elts: - ikeid_elt = phase2_elt.find('ikeid') + ikeid_elt = phase2_elt.find("ikeid") if ikeid_elt is None: continue if ikeid == ikeid_elt.text: @@ -192,221 +230,360 @@ def _remove_phases2(self): # params processing # def _params_to_obj(self): - """ return an ipsec dict from module params """ + """return an ipsec dict from module params""" ipsec = super(PFSenseIpsecModule, self)._params_to_obj() params = self.params - self.apply = params['apply'] - ipsec.pop('apply', None) + self.apply = params["apply"] + ipsec.pop("apply", None) - if params['state'] == 'present': - if params['authentication_method'] == 'rsasig': - ca_elt = self.pfsense.find_ca_elt(params['certificate_authority']) + if params["state"] == "present": + if params["authentication_method"] == "rsasig": + ca_elt = self.pfsense.find_ca_elt(params["certificate_authority"]) if ca_elt is None: - self.module.fail_json(msg='%s is not a valid certificate authority' % (params['certificate_authority'])) - ipsec['caref'] = ca_elt.find('refid').text + self.module.fail_json( + msg="%s is not a valid certificate authority" + % (params["certificate_authority"]) + ) + ipsec["caref"] = ca_elt.find("refid").text - cert = self.pfsense.find_cert_elt(params['certificate']) + cert = self.pfsense.find_cert_elt(params["certificate"]) if cert is None: - self.module.fail_json(msg='%s is not a valid certificate' % (params['certificate'])) - ipsec['certref'] = cert.find('refid').text - ipsec['pre-shared-key'] = '' + self.module.fail_json( + msg="%s is not a valid certificate" % (params["certificate"]) + ) + ipsec["certref"] = cert.find("refid").text + ipsec["pre-shared-key"] = "" else: - ipsec['caref'] = '' - ipsec['certref'] = '' + ipsec["caref"] = "" + ipsec["certref"] = "" - if params.get('disable_rekey'): - ipsec['rekey_enable'] = '' + if params.get("disable_rekey"): + ipsec["rekey_enable"] = "" - if params.get('enable_dpd'): - ipsec['dpd_delay'] = str(params['dpd_delay']) - ipsec['dpd_maxfail'] = str(params['dpd_maxfail']) - del ipsec['enable_dpd'] + if params.get("enable_dpd"): + ipsec["dpd_delay"] = str(params["dpd_delay"]) + ipsec["dpd_maxfail"] = str(params["dpd_maxfail"]) + del ipsec["enable_dpd"] - if params.get('disable_reauth'): - ipsec['reauth_enable'] = '' + if params.get("disable_reauth"): + ipsec["reauth_enable"] = "" return ipsec def _deprecated_params(self): return [ - ['disable_rekey', self.pfsense.is_at_least_2_5_0], - ['margintime', self.pfsense.is_at_least_2_5_0], - ['responderonly', self.pfsense.is_at_least_2_5_2], + ["disable_rekey", self.pfsense.is_at_least_2_5_0], + ["margintime", self.pfsense.is_at_least_2_5_0], + ["responderonly", self.pfsense.is_at_least_2_5_2], ] def _onward_params(self): return [ - ['gw_duplicates', self.pfsense.is_at_least_2_5_0], - ['nattport', self.pfsense.is_at_least_2_5_0], - ['rekey_time', self.pfsense.is_at_least_2_5_0], - ['reauth_time', self.pfsense.is_at_least_2_5_0], - ['rand_time', self.pfsense.is_at_least_2_5_0], + ["gw_duplicates", self.pfsense.is_at_least_2_5_0], + ["nattport", self.pfsense.is_at_least_2_5_0], + ["rekey_time", self.pfsense.is_at_least_2_5_0], + ["reauth_time", self.pfsense.is_at_least_2_5_0], + ["rand_time", self.pfsense.is_at_least_2_5_0], # TODO - Cannot add because it has a default value # ['startaction', self.pfsense.is_at_least_2_5_2], # ['closeaction', self.pfsense.is_at_least_2_5_2], ] def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params['state'] == 'absent': + if params["state"] == "absent": return - if params.get('lifetime') is not None: - if (params.get('rekey_time') is not None and params.get('rekey_time') >= params.get('lifetime') or - params.get('reauth_time') is not None and params.get('reauth_time') >= params.get('lifetime')): - self.module.fail_json(msg='Life Time must be larger than Rekey Time and Reauth Time.') + if params.get("lifetime") is not None: + if ( + params.get("rekey_time") is not None + and params.get("rekey_time") >= params.get("lifetime") + or params.get("reauth_time") is not None + and params.get("reauth_time") >= params.get("lifetime") + ): + self.module.fail_json( + msg="Life Time must be larger than Rekey Time and Reauth Time." + ) for ipsec_elt in self.root_elt: - if ipsec_elt.tag != 'phase1': + if ipsec_elt.tag != "phase1": continue # don't check on ourself - name = ipsec_elt.find('descr') + name = ipsec_elt.find("descr") if name is None: - name = '' + name = "" else: name = name.text - if name == params['descr']: + if name == params["descr"]: continue # Valid interfaces are physical, virtual IPs, and gateway groups # TODO - handle gateway groups - if params['interface'].lower().startswith('vip:'): - if self.pfsense.get_virtual_ip_interface(params['interface'][4:]) is None: - self.module.fail_json(msg='Cannot find virtual IP "{0}".'.format(params['interface'][4:])) + if params["interface"].lower().startswith("vip:"): + if ( + self.pfsense.get_virtual_ip_interface(params["interface"][4:]) + is None + ): + self.module.fail_json( + msg='Cannot find virtual IP "{0}".'.format( + params["interface"][4:] + ) + ) # two ikev2 can share the same gateway - iketype_elt = ipsec_elt.find('iketype') + iketype_elt = ipsec_elt.find("iketype") if iketype_elt is None: continue - if iketype_elt.text == 'ikev2' and iketype_elt.text == params['iketype']: + if iketype_elt.text == "ikev2" and iketype_elt.text == params["iketype"]: continue # others can't share the same gateway - rgw_elt = ipsec_elt.find('remote-gateway') + rgw_elt = ipsec_elt.find("remote-gateway") if rgw_elt is None: continue - if rgw_elt.text == params['remote_gateway']: - self.module.fail_json(msg='The remote gateway "{0}" is already used by phase1 "{1}".'.format(params['remote_gateway'], name)) + if rgw_elt.text == params["remote_gateway"]: + self.module.fail_json( + msg='The remote gateway "{0}" is already used by phase1 "{1}".'.format( + params["remote_gateway"], name + ) + ) ############################## # run # def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" return self.pfsense.apply_ipsec_changes() ############################## # Logging # def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.diff['after'], 'iketype') - if self.diff['after']['iketype'] != 'ikev2': - values += self.format_cli_field(self.diff['after'], 'mode') - - values += self.format_cli_field(self.diff['after'], 'protocol') - values += self.format_cli_field(self.params, 'interface') - values += self.format_cli_field(self.diff['after'], 'remote-gateway', fname='remote_gateway') - values += self.format_cli_field(self.diff['after'], 'nattport') - values += self.format_cli_field(self.diff['after'], 'authentication_method') - if self.diff['after']['authentication_method'] == 'rsasig': - values += self.format_cli_field(self.params, 'certificate') - values += self.format_cli_field(self.params, 'certificate_authority') + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool + ) + values += self.format_cli_field(self.diff["after"], "iketype") + if self.diff["after"]["iketype"] != "ikev2": + values += self.format_cli_field(self.diff["after"], "mode") + + values += self.format_cli_field(self.diff["after"], "protocol") + values += self.format_cli_field(self.params, "interface") + values += self.format_cli_field( + self.diff["after"], "remote-gateway", fname="remote_gateway" + ) + values += self.format_cli_field(self.diff["after"], "nattport") + values += self.format_cli_field(self.diff["after"], "authentication_method") + if self.diff["after"]["authentication_method"] == "rsasig": + values += self.format_cli_field(self.params, "certificate") + values += self.format_cli_field(self.params, "certificate_authority") else: - values += self.format_cli_field(self.diff['after'], 'pre-shared-key', fname='preshared_key') - - id_types = ['address', 'fqdn', 'user_fqdn', 'asn1dn', 'keyid tag', 'dyn_dns'] - values += self.format_cli_field(self.diff['after'], 'myid_type') - if self.diff['after']['myid_type'] in id_types: - values += self.format_cli_field(self.diff['after'], 'myid_data') - - values += self.format_cli_field(self.diff['after'], 'peerid_type') - if self.diff['after']['peerid_type'] in id_types: - values += self.format_cli_field(self.diff['after'], 'peerid_data') - - values += self.format_cli_field(self.diff['after'], 'lifetime') - values += self.format_cli_field(self.diff['after'], 'rekey_time') - values += self.format_cli_field(self.diff['after'], 'reauth_time') - values += self.format_cli_field(self.diff['after'], 'rand_time') - - if self.diff['after']['iketype'] == 'ikev2': - values += self.format_cli_field(self.diff['after'], 'reauth_enable', fname='disable_reauth', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.diff['after'], 'mobike') - values += self.format_cli_field(self.diff['after'], 'splitconn', fvalue=self.fvalue_bool) - - values += self.format_cli_field(self.diff['after'], 'gw_duplicates', fvalue=self.fvalue_bool) - - values += self.format_cli_field(self.params, 'startaction') - values += self.format_cli_field(self.params, 'closeaction') - values += self.format_cli_field(self.diff['after'], 'nat_traversal') - - values += self.format_cli_field(self.params, 'enable_dpd', fvalue=self.fvalue_bool) - if self.params['enable_dpd']: - values += self.format_cli_field(self.diff['after'], 'dpd_delay') - values += self.format_cli_field(self.diff['after'], 'dpd_maxfail') + values += self.format_cli_field( + self.diff["after"], "pre-shared-key", fname="preshared_key" + ) + + id_types = [ + "address", + "fqdn", + "user_fqdn", + "asn1dn", + "keyid tag", + "dyn_dns", + ] + values += self.format_cli_field(self.diff["after"], "myid_type") + if self.diff["after"]["myid_type"] in id_types: + values += self.format_cli_field(self.diff["after"], "myid_data") + + values += self.format_cli_field(self.diff["after"], "peerid_type") + if self.diff["after"]["peerid_type"] in id_types: + values += self.format_cli_field(self.diff["after"], "peerid_data") + + values += self.format_cli_field(self.diff["after"], "lifetime") + values += self.format_cli_field(self.diff["after"], "rekey_time") + values += self.format_cli_field(self.diff["after"], "reauth_time") + values += self.format_cli_field(self.diff["after"], "rand_time") + + if self.diff["after"]["iketype"] == "ikev2": + values += self.format_cli_field( + self.diff["after"], + "reauth_enable", + fname="disable_reauth", + fvalue=self.fvalue_bool, + ) + values += self.format_cli_field(self.diff["after"], "mobike") + values += self.format_cli_field( + self.diff["after"], "splitconn", fvalue=self.fvalue_bool + ) + + values += self.format_cli_field( + self.diff["after"], "gw_duplicates", fvalue=self.fvalue_bool + ) + + values += self.format_cli_field(self.params, "startaction") + values += self.format_cli_field(self.params, "closeaction") + values += self.format_cli_field(self.diff["after"], "nat_traversal") + + values += self.format_cli_field( + self.params, "enable_dpd", fvalue=self.fvalue_bool + ) + if self.params["enable_dpd"]: + values += self.format_cli_field(self.diff["after"], "dpd_delay") + values += self.format_cli_field(self.diff["after"], "dpd_maxfail") else: - values += self.format_updated_cli_field(self.diff['after'], before, 'disabled', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.diff['after'], before, 'iketype', add_comma=(values)) - if self.diff['after']['iketype'] != 'ikev2': - values += self.format_updated_cli_field(self.diff['after'], before, 'mode', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'protocol', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'interface', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'remote-gateway', add_comma=(values), fname='remote_gateway') - values += self.format_updated_cli_field(self.diff['after'], before, 'nattport', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'authentication_method', add_comma=(values)) - if self.diff['after']['authentication_method'] == 'rsasig': - values += self.format_updated_cli_field(self.params, before, 'certificate', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'certificate_authority', add_comma=(values)) + values += self.format_updated_cli_field( + self.diff["after"], + before, + "disabled", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "iketype", add_comma=(values) + ) + if self.diff["after"]["iketype"] != "ikev2": + values += self.format_updated_cli_field( + self.diff["after"], before, "mode", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "protocol", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "interface", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], + before, + "remote-gateway", + add_comma=(values), + fname="remote_gateway", + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "nattport", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "authentication_method", add_comma=(values) + ) + if self.diff["after"]["authentication_method"] == "rsasig": + values += self.format_updated_cli_field( + self.params, before, "certificate", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "certificate_authority", add_comma=(values) + ) else: - values += self.format_updated_cli_field(self.diff['after'], before, 'pre-shared-key', add_comma=(values), fname='preshared_key') - values += self.format_updated_cli_field(self.diff['after'], before, 'myid_type', add_comma=(values)) - id_types = ['address', 'fqdn', 'user_fqdn', 'asn1dn', 'keyid tag', 'dyn_dns'] - if self.diff['after']['myid_type'] in id_types: - values += self.format_updated_cli_field(self.diff['after'], before, 'myid_data', add_comma=(values)) - - values += self.format_updated_cli_field(self.diff['after'], before, 'peerid_type', add_comma=(values)) - if self.diff['after']['peerid_type'] in id_types: - values += self.format_updated_cli_field(self.diff['after'], before, 'peerid_data', add_comma=(values)) - - values += self.format_updated_cli_field(self.diff['after'], before, 'lifetime', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'rekey_time', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'reauth_time', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'rand_time', add_comma=(values)) - - if self.diff['after']['iketype'] == 'ikev2': - values += self.format_updated_cli_field(self.diff['after'], before, 'reauth_enable', add_comma=(values), fname='disable_reauth', - fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.diff['after'], before, 'mobike', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'splitconn', add_comma=(values), fvalue=self.fvalue_bool) - - values += self.format_updated_cli_field(self.diff['after'], before, 'gw_duplicates', add_comma=(values), fvalue=self.fvalue_bool) - - values += self.format_updated_cli_field(self.diff['after'], before, 'startaction', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'closeaction', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'nat_traversal', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'enable_dpd', add_comma=(values), fvalue=self.fvalue_bool) - if self.params['enable_dpd']: - values += self.format_updated_cli_field(self.diff['after'], before, 'dpd_delay', add_comma=(values)) - values += self.format_updated_cli_field(self.diff['after'], before, 'dpd_maxfail', add_comma=(values)) + values += self.format_updated_cli_field( + self.diff["after"], + before, + "pre-shared-key", + add_comma=(values), + fname="preshared_key", + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "myid_type", add_comma=(values) + ) + id_types = [ + "address", + "fqdn", + "user_fqdn", + "asn1dn", + "keyid tag", + "dyn_dns", + ] + if self.diff["after"]["myid_type"] in id_types: + values += self.format_updated_cli_field( + self.diff["after"], before, "myid_data", add_comma=(values) + ) + + values += self.format_updated_cli_field( + self.diff["after"], before, "peerid_type", add_comma=(values) + ) + if self.diff["after"]["peerid_type"] in id_types: + values += self.format_updated_cli_field( + self.diff["after"], before, "peerid_data", add_comma=(values) + ) + + values += self.format_updated_cli_field( + self.diff["after"], before, "lifetime", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "rekey_time", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "reauth_time", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "rand_time", add_comma=(values) + ) + + if self.diff["after"]["iketype"] == "ikev2": + values += self.format_updated_cli_field( + self.diff["after"], + before, + "reauth_enable", + add_comma=(values), + fname="disable_reauth", + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "mobike", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], + before, + "splitconn", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + + values += self.format_updated_cli_field( + self.diff["after"], + before, + "gw_duplicates", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + + values += self.format_updated_cli_field( + self.diff["after"], before, "startaction", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "closeaction", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "nat_traversal", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], + before, + "enable_dpd", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + if self.params["enable_dpd"]: + values += self.format_updated_cli_field( + self.diff["after"], before, "dpd_delay", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.diff["after"], before, "dpd_maxfail", add_comma=(values) + ) return values def _get_ref_names(self, before): - """ get cert and ca names """ - if before['caref'] is not None and before['caref'] != '': - elt = self.pfsense.find_ca_elt(before['caref'], 'refid') + """get cert and ca names""" + if before["caref"] is not None and before["caref"] != "": + elt = self.pfsense.find_ca_elt(before["caref"], "refid") if elt is not None: - before['certificate_authority'] = elt.find('descr').text + before["certificate_authority"] = elt.find("descr").text - if before['certref'] is not None and before['certref'] != '': - elt = self.pfsense.find_cert_elt(before['certref'], 'refid') + if before["certref"] is not None and before["certref"] != "": + elt = self.pfsense.find_cert_elt(before["certref"], "refid") if elt is not None: - before['certificate'] = elt.find('descr').text + before["certificate"] = elt.find("descr").text diff --git a/plugins/module_utils/ipsec_p2.py b/plugins/module_utils/ipsec_p2.py index 1f85f490..49e463a7 100644 --- a/plugins/module_utils/ipsec_p2.py +++ b/plugins/module_utils/ipsec_p2.py @@ -4,64 +4,85 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import PFSenseModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import ( + PFSenseModule, +) +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) from copy import deepcopy IPSEC_P2_ARGUMENT_SPEC = dict( - apply=dict(default=True, type='bool'), - state=dict(default='present', choices=['present', 'absent']), - descr=dict(required=True, type='str'), - p1_descr=dict(required=True, type='str'), - - disabled=dict(default=False, type='bool'), - mode=dict(choices=['tunnel', 'tunnel6', 'transport', 'vti'], type='str'), - protocol=dict(default='esp', choices=['esp', 'ah'], type='str'), - + apply=dict(default=True, type="bool"), + state=dict(default="present", choices=["present", "absent"]), + descr=dict(required=True, type="str"), + p1_descr=dict(required=True, type="str"), + disabled=dict(default=False, type="bool"), + mode=dict(choices=["tunnel", "tunnel6", "transport", "vti"], type="str"), + protocol=dict(default="esp", choices=["esp", "ah"], type="str"), # addresses - local=dict(required=False, type='str'), - nat=dict(required=False, type='str'), - remote=dict(required=False, type='str'), - + local=dict(required=False, type="str"), + nat=dict(required=False, type="str"), + remote=dict(required=False, type="str"), # encryptions - aes=dict(required=False, type='bool'), - aes128gcm=dict(required=False, type='bool'), - aes192gcm=dict(required=False, type='bool'), - aes256gcm=dict(required=False, type='bool'), - blowfish=dict(required=False, type='bool'), - des=dict(required=False, type='bool'), - cast128=dict(required=False, type='bool'), - aes_len=dict(required=False, choices=['auto', '128', '192', '256'], type='str'), - aes128gcm_len=dict(required=False, choices=['auto', '64', '96', '128'], type='str'), - aes192gcm_len=dict(required=False, choices=['auto', '64', '96', '128'], type='str'), - aes256gcm_len=dict(required=False, choices=['auto', '64', '96', '128'], type='str'), - blowfish_len=dict(required=False, choices=['auto', '128', '192', '256'], type='str'), - + aes=dict(required=False, type="bool"), + aes128gcm=dict(required=False, type="bool"), + aes192gcm=dict(required=False, type="bool"), + aes256gcm=dict(required=False, type="bool"), + blowfish=dict(required=False, type="bool"), + des=dict(required=False, type="bool"), + cast128=dict(required=False, type="bool"), + aes_len=dict(required=False, choices=["auto", "128", "192", "256"], type="str"), + aes128gcm_len=dict(required=False, choices=["auto", "64", "96", "128"], type="str"), + aes192gcm_len=dict(required=False, choices=["auto", "64", "96", "128"], type="str"), + aes256gcm_len=dict(required=False, choices=["auto", "64", "96", "128"], type="str"), + blowfish_len=dict( + required=False, choices=["auto", "128", "192", "256"], type="str" + ), # hashes - sha1=dict(required=False, type='bool'), - sha256=dict(required=False, type='bool'), - sha384=dict(required=False, type='bool'), - sha512=dict(required=False, type='bool'), - aesxcbc=dict(required=False, type='bool'), - + sha1=dict(required=False, type="bool"), + sha256=dict(required=False, type="bool"), + sha384=dict(required=False, type="bool"), + sha512=dict(required=False, type="bool"), + aesxcbc=dict(required=False, type="bool"), # misc pfsgroup=dict( - default='14', - choices=['0', '1', '2', '5', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '28', '29', '30', '31', '32'], - type='str' + default="14", + choices=[ + "0", + "1", + "2", + "5", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "28", + "29", + "30", + "31", + "32", + ], + type="str", ), - lifetime=dict(default=3600, type='int'), - pinghost=dict(required=False, type='str') + lifetime=dict(default=3600, type="int"), + pinghost=dict(required=False, type="str"), ) IPSEC_P2_REQUIRED_IF = [ ["state", "present", ["mode"]], - ["mode", "tunnel", ["local", "remote"]], ["mode", "tunnel6", ["local", "remote"]], ["mode", "vti", ["local", "remote"]], - # encryptions ["aes", True, ["aes_len"]], ["aes128gcm", True, ["aes128gcm_len"]], @@ -72,11 +93,11 @@ class PFSenseIpsecP2Module(PFSenseModuleBase): - """ module managing pfsense ipsec phase 2 options and proposals """ + """module managing pfsense ipsec phase 2 options and proposals""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return IPSEC_P2_ARGUMENT_SPEC ############################## @@ -101,111 +122,142 @@ def __init__(self, module, pfsense=None): # params processing # def _check_for_duplicate_phase2(self, phase2): - """ check for another phase2 with same remote and local """ + """check for another phase2 with same remote and local""" + def strip_phase(phase): _phase2 = {} - if phase.get('localid') is not None: - _phase2['localid'] = phase['localid'] - if phase.get('remoteid') is not None: - _phase2['remoteid'] = phase['remoteid'] + if phase.get("localid") is not None: + _phase2["localid"] = phase["localid"] + if phase.get("remoteid") is not None: + _phase2["remoteid"] = phase["remoteid"] return _phase2 _phase2 = strip_phase(phase2) - ikeid = self._phase1.find('ikeid').text + ikeid = self._phase1.find("ikeid").text for phase2_elt in self.root_elt: - if phase2_elt.tag != 'phase2': + if phase2_elt.tag != "phase2": continue - if phase2_elt.find('ikeid').text != ikeid: + if phase2_elt.find("ikeid").text != ikeid: continue - if phase2_elt.find('descr').text == phase2['descr']: + if phase2_elt.find("descr").text == phase2["descr"]: continue other_phase2 = self.pfsense.element_to_dict(phase2_elt) if _phase2 == strip_phase(other_phase2): - self.module.fail_json(msg='Phase2 with this Local/Remote networks combination is already defined for this Phase1.') + self.module.fail_json( + msg="Phase2 with this Local/Remote networks combination is already defined for this Phase1." + ) def _id_to_phase2(self, name, phase2, address, param_name): - """ setup ipsec phase2 with address """ + """setup ipsec phase2 with address""" + def set_ip_address(): - phase2[name]['type'] = 'address' - phase2[name]['address'] = address + phase2[name]["type"] = "address" + phase2[name]["address"] = address def set_ip_network(): - phase2[name]['type'] = 'network' - (phase2[name]['address'], phase2[name]['netbits']) = self.pfsense.parse_ip_network(address, False) - phase2[name]['netbits'] = str(phase2[name]['netbits']) + phase2[name]["type"] = "network" + (phase2[name]["address"], phase2[name]["netbits"]) = ( + self.pfsense.parse_ip_network(address, False) + ) + phase2[name]["netbits"] = str(phase2[name]["netbits"]) + phase2[name] = dict() - interface = self.pfsense.parse_interface(address, fail=False, with_virtual=False) + interface = self.pfsense.parse_interface( + address, fail=False, with_virtual=False + ) if interface is not None: - if phase2['mode'] == 'vti': - msg = 'VTI requires a valid local network or IP address for its endpoint address.' + if phase2["mode"] == "vti": + msg = "VTI requires a valid local network or IP address for its endpoint address." self.module.fail_json(msg=msg) - phase2[name]['type'] = interface + phase2[name]["type"] = interface elif self.pfsense.is_ipv4_address(address): - if self.params['mode'] == 'tunnel6': - self.module.fail_json(msg='A valid IPv6 address or network must be specified in {0} with tunnel6.'.format(param_name)) + if self.params["mode"] == "tunnel6": + self.module.fail_json( + msg="A valid IPv6 address or network must be specified in {0} with tunnel6.".format( + param_name + ) + ) set_ip_address() elif self.pfsense.is_ipv6_address(address): - if self.params['mode'] == 'tunnel': - self.module.fail_json(msg='A valid IPv4 address or network must be specified in {0} with tunnel.'.format(param_name)) + if self.params["mode"] == "tunnel": + self.module.fail_json( + msg="A valid IPv4 address or network must be specified in {0} with tunnel.".format( + param_name + ) + ) set_ip_address() elif self.pfsense.is_ipv4_network(address, False): - if self.params['mode'] == 'tunnel6': - self.module.fail_json(msg='A valid IPv6 address or network must be specified in {0} with tunnel6.'.format(param_name)) + if self.params["mode"] == "tunnel6": + self.module.fail_json( + msg="A valid IPv6 address or network must be specified in {0} with tunnel6.".format( + param_name + ) + ) set_ip_network() elif self.pfsense.is_ipv6_network(address, False): - if self.params['mode'] == 'tunnel': - self.module.fail_json(msg='A valid IPv4 address or network must be specified in {0} with tunnel.'.format(param_name)) + if self.params["mode"] == "tunnel": + self.module.fail_json( + msg="A valid IPv4 address or network must be specified in {0} with tunnel.".format( + param_name + ) + ) set_ip_network() else: - self.module.fail_json(msg='A valid IP address, network or interface must be specified in {0}.'.format(param_name)) + self.module.fail_json( + msg="A valid IP address, network or interface must be specified in {0}.".format( + param_name + ) + ) def _params_to_obj(self): - """ return an phase2 dict from module params """ + """return an phase2 dict from module params""" params = self.params obj = dict() - obj['descr'] = params['descr'] - self.apply = params['apply'] - - if params['state'] == 'present': - obj['mode'] = params['mode'] - if obj['mode'] != 'transport': - - if obj['mode'] == 'vti' and not self.pfsense.is_ipv4_address(params['remote']): - msg = 'VTI requires a valid remote IP address for its endpoint address.' + obj["descr"] = params["descr"] + self.apply = params["apply"] + + if params["state"] == "present": + obj["mode"] = params["mode"] + if obj["mode"] != "transport": + if obj["mode"] == "vti" and not self.pfsense.is_ipv4_address( + params["remote"] + ): + msg = "VTI requires a valid remote IP address for its endpoint address." self.module.fail_json(msg=msg) - self._id_to_phase2('localid', obj, params['local'], 'local') - self._id_to_phase2('remoteid', obj, params['remote'], 'remote') + self._id_to_phase2("localid", obj, params["local"], "local") + self._id_to_phase2("remoteid", obj, params["remote"], "remote") - if obj['mode'] != 'vti' and params.get('nat') is not None: - self._id_to_phase2('natlocalid', obj, params['nat'], 'nat') + if obj["mode"] != "vti" and params.get("nat") is not None: + self._id_to_phase2("natlocalid", obj, params["nat"], "nat") - if params.get('disabled'): - obj['disabled'] = '' + if params.get("disabled"): + obj["disabled"] = "" - obj['protocol'] = params['protocol'] - obj['pfsgroup'] = params['pfsgroup'] - if params.get('lifetime') is not None and params['lifetime'] > 0: - obj['lifetime'] = str(params['lifetime']) + obj["protocol"] = params["protocol"] + obj["pfsgroup"] = params["pfsgroup"] + if params.get("lifetime") is not None and params["lifetime"] > 0: + obj["lifetime"] = str(params["lifetime"]) else: - obj['lifetime'] = '' + obj["lifetime"] = "" - if obj.get('pinghost'): - obj['pinghost'] = params['pinghost'] + if obj.get("pinghost"): + obj["pinghost"] = params["pinghost"] else: - obj['pinghost'] = '' + obj["pinghost"] = "" self._check_for_duplicate_phase2(obj) return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" + def has_one_of(bools): for name in bools: if params.get(name): @@ -215,40 +267,58 @@ def has_one_of(bools): params = self.params # called from ipsec_aggregate - if params.get('ikeid') is not None: - self._phase1 = self.pfsense.find_ipsec_phase1(params['ikeid'], 'ikeid') + if params.get("ikeid") is not None: + self._phase1 = self.pfsense.find_ipsec_phase1(params["ikeid"], "ikeid") if self._phase1 is None: - self.module.fail_json(msg='No ipsec tunnel with ikeid {0}'.format(params['ikeid'])) + self.module.fail_json( + msg="No ipsec tunnel with ikeid {0}".format(params["ikeid"]) + ) else: - self._phase1 = self.pfsense.find_ipsec_phase1(params['p1_descr']) + self._phase1 = self.pfsense.find_ipsec_phase1(params["p1_descr"]) if self._phase1 is None: - self.module.fail_json(msg='No ipsec tunnel named {0}'.format(params['p1_descr'])) - - if params['state'] == 'present': - encs = ['aes', 'aes128gcm', 'aes192gcm', 'aes256gcm', 'blowfish', 'des', 'cast128'] - if params['protocol'] == 'esp' and not has_one_of(encs): - self.module.fail_json(msg='At least one encryption algorithm must be selected.') + self.module.fail_json( + msg="No ipsec tunnel named {0}".format(params["p1_descr"]) + ) + + if params["state"] == "present": + encs = [ + "aes", + "aes128gcm", + "aes192gcm", + "aes256gcm", + "blowfish", + "des", + "cast128", + ] + if params["protocol"] == "esp" and not has_one_of(encs): + self.module.fail_json( + msg="At least one encryption algorithm must be selected." + ) if self.pfsense.is_at_least_2_5_0(): - need_one_hash = has_one_of(['aes', 'blowfish', 'des', 'cast128']) + need_one_hash = has_one_of(["aes", "blowfish", "des", "cast128"]) else: need_one_hash = True - if need_one_hash and not has_one_of(['sha1', 'sha256', 'sha384', 'sha512', 'aesxcbc']): - self.module.fail_json(msg='At least one hashing algorithm needs to be selected.') + if need_one_hash and not has_one_of( + ["sha1", "sha256", "sha384", "sha512", "aesxcbc"] + ): + self.module.fail_json( + msg="At least one hashing algorithm needs to be selected." + ) ############################## # XML processing # def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) self._sync_encryptions(self.target_elt) self._sync_hashes(self.target_elt) self.root_elt.append(self.target_elt) def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" self.before_elt = deepcopy(self.target_elt) before = self.pfsense.element_to_dict(self.target_elt) changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) @@ -265,22 +335,22 @@ def _copy_and_update_target(self): return (before, changed) def _create_target(self): - """ create the XML target_elt """ - target_elt = self.pfsense.new_element('phase2') - self.obj['ikeid'] = self._phase1.find('ikeid').text - self.obj['uniqid'] = self.pfsense.uniqid() - self.obj['reqid'] = str(self._find_free_reqid()) + """create the XML target_elt""" + target_elt = self.pfsense.new_element("phase2") + self.obj["ikeid"] = self._phase1.find("ikeid").text + self.obj["uniqid"] = self.pfsense.uniqid() + self.obj["reqid"] = str(self._find_free_reqid()) return target_elt def _find_free_reqid(self): - """ return first unused reqid """ + """return first unused reqid""" reqid = 1 while True: found = False for phase2_elt in self.root_elt: - if phase2_elt.tag != 'phase2': + if phase2_elt.tag != "phase2": continue - reqid_elt = phase2_elt.find('reqid') + reqid_elt = phase2_elt.find("reqid") if reqid_elt is not None and reqid_elt.text == str(reqid): found = True break @@ -290,49 +360,55 @@ def _find_free_reqid(self): reqid = reqid + 1 def _find_target(self): - """ return ipsec phase2 elt if found """ - ikeid = self._phase1.find('ikeid').text + """return ipsec phase2 elt if found""" + ikeid = self._phase1.find("ikeid").text for phase2_elt in self.root_elt: - if phase2_elt.tag != 'phase2': + if phase2_elt.tag != "phase2": continue - if phase2_elt.find('ikeid').text != ikeid: + if phase2_elt.find("ikeid").text != ikeid: continue - descr_elt = phase2_elt.find('descr') - if descr_elt is not None and descr_elt.text == self.obj['descr']: + descr_elt = phase2_elt.find("descr") + if descr_elt is not None and descr_elt.text == self.obj["descr"]: return phase2_elt return None def _remove_deleted_ipsec_params(self): - """ Remove from phase2 a few deleted params """ + """Remove from phase2 a few deleted params""" changed = False - params = ['disabled'] + params = ["disabled"] for param in params: - if self.pfsense.remove_deleted_param_from_elt(self.target_elt, param, self.obj): + if self.pfsense.remove_deleted_param_from_elt( + self.target_elt, param, self.obj + ): changed = True - for param in ['localid', 'remoteid', 'natlocalid']: + for param in ["localid", "remoteid", "natlocalid"]: if self._remove_extra_deleted_ipsec_params(param): changed = True return changed def _remove_extra_deleted_ipsec_params(self, name): - """ Remove from phase2 a few extra deleted params """ + """Remove from phase2 a few extra deleted params""" changed = False - params = ['type', 'address', 'netbits'] + params = ["type", "address", "netbits"] sub_elt = self.target_elt.find(name) if sub_elt is not None: for param in params: if name in self.obj: - if self.pfsense.remove_deleted_param_from_elt(sub_elt, param, self.obj[name]): + if self.pfsense.remove_deleted_param_from_elt( + sub_elt, param, self.obj[name] + ): changed = True else: - if self.pfsense.remove_deleted_param_from_elt(sub_elt, param, dict()): + if self.pfsense.remove_deleted_param_from_elt( + sub_elt, param, dict() + ): changed = True if len(sub_elt) == 0: @@ -341,10 +417,11 @@ def _remove_extra_deleted_ipsec_params(self, name): return changed def _sync_encryptions(self, phase2_elt): - """ sync encryptions params """ + """sync encryptions params""" + def get_encryption(encryptions_elt, name): for encryption_elt in encryptions_elt: - name_elt = encryption_elt.find('name') + name_elt = encryption_elt.find("name") if name_elt is not None and name_elt.text == name: return encryption_elt return None @@ -353,11 +430,13 @@ def sync_encryption(encryptions_elt, name, param_name): encryption_elt = get_encryption(encryptions_elt, name) if self.params.get(param_name): encryption = dict() - encryption['name'] = name - if self.params.get(param_name + '_len') is not None: - encryption['keylen'] = self.params[param_name + '_len'] + encryption["name"] = name + if self.params.get(param_name + "_len") is not None: + encryption["keylen"] = self.params[param_name + "_len"] if encryption_elt is None: - encryption_elt = self.pfsense.new_element('encryption-algorithm-option') + encryption_elt = self.pfsense.new_element( + "encryption-algorithm-option" + ) self.pfsense.copy_dict_to_element(encryption, encryption_elt) phase2_elt.append(encryption_elt) return True @@ -373,25 +452,26 @@ def sync_encryption(encryptions_elt, name, param_name): return False changed = False - encryptions_elt = phase2_elt.findall('encryption-algorithm-option') - if sync_encryption(encryptions_elt, 'aes', 'aes'): + encryptions_elt = phase2_elt.findall("encryption-algorithm-option") + if sync_encryption(encryptions_elt, "aes", "aes"): changed = True - if sync_encryption(encryptions_elt, 'aes128gcm', 'aes128gcm'): + if sync_encryption(encryptions_elt, "aes128gcm", "aes128gcm"): changed = True - if sync_encryption(encryptions_elt, 'aes192gcm', 'aes192gcm'): + if sync_encryption(encryptions_elt, "aes192gcm", "aes192gcm"): changed = True - if sync_encryption(encryptions_elt, 'aes256gcm', 'aes256gcm'): + if sync_encryption(encryptions_elt, "aes256gcm", "aes256gcm"): changed = True - if sync_encryption(encryptions_elt, 'blowfish', 'blowfish'): + if sync_encryption(encryptions_elt, "blowfish", "blowfish"): changed = True - if sync_encryption(encryptions_elt, '3des', 'des'): + if sync_encryption(encryptions_elt, "3des", "des"): changed = True - if sync_encryption(encryptions_elt, 'cast128', 'cast128'): + if sync_encryption(encryptions_elt, "cast128", "cast128"): changed = True return changed def _sync_hashes(self, phase2_elt): - """ sync hashes params """ + """sync hashes params""" + def get_hash(hashes_elt, name): for hash_elt in hashes_elt: if hash_elt.text == name: @@ -401,7 +481,7 @@ def get_hash(hashes_elt, name): def sync_hash(hashes_elt, name, param_name): if self.params.get(param_name) is True: if get_hash(hashes_elt, name) is None: - hash_elt = self.pfsense.new_element('hash-algorithm-option') + hash_elt = self.pfsense.new_element("hash-algorithm-option") hash_elt.text = name phase2_elt.append(hash_elt) return True @@ -413,16 +493,16 @@ def sync_hash(hashes_elt, name, param_name): return False changed = False - hashes_elt = phase2_elt.findall('hash-algorithm-option') - if sync_hash(hashes_elt, 'hmac_sha1', 'sha1'): + hashes_elt = phase2_elt.findall("hash-algorithm-option") + if sync_hash(hashes_elt, "hmac_sha1", "sha1"): changed = True - if sync_hash(hashes_elt, 'hmac_sha256', 'sha256'): + if sync_hash(hashes_elt, "hmac_sha256", "sha256"): changed = True - if sync_hash(hashes_elt, 'hmac_sha384', 'sha384'): + if sync_hash(hashes_elt, "hmac_sha384", "sha384"): changed = True - if sync_hash(hashes_elt, 'hmac_sha512', 'sha512'): + if sync_hash(hashes_elt, "hmac_sha512", "sha512"): changed = True - if sync_hash(hashes_elt, 'aesxcbc', 'aesxcbc'): + if sync_hash(hashes_elt, "aesxcbc", "aesxcbc"): changed = True return changed @@ -436,118 +516,236 @@ def _update(self): # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}' on '{1}'".format(self.obj['descr'], self.params['p1_descr']) + """return obj's name""" + return "'{0}' on '{1}'".format(self.obj["descr"], self.params["p1_descr"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ + """generate pseudo-CLI command fields parameters to create an obj""" + def log_enc(name): - log = '' + log = "" log += self.format_cli_field(self.params, name, fvalue=self.fvalue_bool) - if self.params.get(name) and self.params.get(name + '_len') is not None: - log += self.format_cli_field(self.params, name + '_len') + if self.params.get(name) and self.params.get(name + "_len") is not None: + log += self.format_cli_field(self.params, name + "_len") return log - values = '' + + values = "" if before is None: - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.obj, 'mode') - - values += self.format_cli_field(self.params, 'local') - values += self.format_cli_field(self.params, 'remote') - values += self.format_cli_field(self.params, 'nat') - - values += log_enc('aes') - values += log_enc('aes128gcm') - values += log_enc('aes192gcm') - values += log_enc('aes256gcm') - values += log_enc('blowfish') - values += log_enc('des') - values += log_enc('cast128') - - values += self.format_cli_field(self.params, 'sha1', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'sha256', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'sha384', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'sha512', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'aesxcbc', fvalue=self.fvalue_bool) - - values += self.format_cli_field(self.params, 'pfsgroup') - values += self.format_cli_field(self.params, 'lifetime') - values += self.format_cli_field(self.params, 'pinghost') + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool + ) + values += self.format_cli_field(self.obj, "mode") + + values += self.format_cli_field(self.params, "local") + values += self.format_cli_field(self.params, "remote") + values += self.format_cli_field(self.params, "nat") + + values += log_enc("aes") + values += log_enc("aes128gcm") + values += log_enc("aes192gcm") + values += log_enc("aes256gcm") + values += log_enc("blowfish") + values += log_enc("des") + values += log_enc("cast128") + + values += self.format_cli_field( + self.params, "sha1", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.params, "sha256", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.params, "sha384", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.params, "sha512", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.params, "aesxcbc", fvalue=self.fvalue_bool + ) + + values += self.format_cli_field(self.params, "pfsgroup") + values += self.format_cli_field(self.params, "lifetime") + values += self.format_cli_field(self.params, "pinghost") else: - self._prepare_log_address(before, 'local', 'localid') - self._prepare_log_address(before, 'nat', 'natlocalid') - self._prepare_log_address(before, 'remote', 'remoteid') + self._prepare_log_address(before, "local", "localid") + self._prepare_log_address(before, "nat", "natlocalid") + self._prepare_log_address(before, "remote", "remoteid") self._prepare_log_encryptions(before, self.before_elt) self._prepare_log_hashes(before, self.before_elt) - values += self.format_updated_cli_field(self.obj, before, 'disabled', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.obj, before, 'mode', add_comma=(values)) - - values += self.format_updated_cli_field(self.params, before, 'local', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'remote', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'nat', add_comma=(values)) - - values += self.format_updated_cli_field(self.params, before, 'aes', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'aes_len', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'aes128gcm', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'aes128gcm_len', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'aes192gcm', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'aes192gcm_len', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'aes256gcm', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'aes256gcm_len', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'blowfish', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'blowfish_len', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'des', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'cast128', add_comma=(values), fvalue=self.fvalue_bool) - - values += self.format_updated_cli_field(self.params, before, 'sha1', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'sha256', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'sha384', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'sha512', add_comma=(values), fvalue=self.fvalue_bool) - values += self.format_updated_cli_field(self.params, before, 'aesxcbc', add_comma=(values), fvalue=self.fvalue_bool) - - values += self.format_updated_cli_field(self.obj, before, 'pfsgroup', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'lifetime', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'pinghost', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, + before, + "disabled", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.obj, before, "mode", add_comma=(values) + ) + + values += self.format_updated_cli_field( + self.params, before, "local", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "remote", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "nat", add_comma=(values) + ) + + values += self.format_updated_cli_field( + self.params, before, "aes", add_comma=(values), fvalue=self.fvalue_bool + ) + values += self.format_updated_cli_field( + self.params, before, "aes_len", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, + before, + "aes128gcm", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, before, "aes128gcm_len", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, + before, + "aes192gcm", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, before, "aes192gcm_len", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, + before, + "aes256gcm", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, before, "aes256gcm_len", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, + before, + "blowfish", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, before, "blowfish_len", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "des", add_comma=(values), fvalue=self.fvalue_bool + ) + values += self.format_updated_cli_field( + self.params, + before, + "cast128", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + + values += self.format_updated_cli_field( + self.params, before, "sha1", add_comma=(values), fvalue=self.fvalue_bool + ) + values += self.format_updated_cli_field( + self.params, + before, + "sha256", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, + before, + "sha384", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, + before, + "sha512", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + values += self.format_updated_cli_field( + self.params, + before, + "aesxcbc", + add_comma=(values), + fvalue=self.fvalue_bool, + ) + + values += self.format_updated_cli_field( + self.obj, before, "pfsgroup", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "lifetime", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "pinghost", add_comma=(values) + ) return values def _prepare_log_address(self, before, param, name): - """ reparse some params for logging """ - if before.get(name) is None or not isinstance(before[name], dict) or before[name].get('type') is None: + """reparse some params for logging""" + if ( + before.get(name) is None + or not isinstance(before[name], dict) + or before[name].get("type") is None + ): before[param] = None return - if before[name]['type'] == 'address': - before[param] = before[name]['address'] - elif before[name]['type'] == 'network': - before[param] = before[name]['address'] + '/' + str(before[name]['netbits']) + if before[name]["type"] == "address": + before[param] = before[name]["address"] + elif before[name]["type"] == "network": + before[param] = before[name]["address"] + "/" + str(before[name]["netbits"]) else: - before[param] = self.pfsense.get_interface_display_name(before[name]['type']) + before[param] = self.pfsense.get_interface_display_name( + before[name]["type"] + ) @staticmethod def _prepare_log_encryptions(before, before_elt): - """ reparse some params for logging """ - encryptions_elt = before_elt.findall('encryption-algorithm-option') + """reparse some params for logging""" + encryptions_elt = before_elt.findall("encryption-algorithm-option") for encryption_elt in encryptions_elt: - name = encryption_elt.find('name').text - len_elt = encryption_elt.find('keylen') - if name == '3des': - name = 'des' + name = encryption_elt.find("name").text + len_elt = encryption_elt.find("keylen") + if name == "3des": + name = "des" before[name] = True if len_elt is not None: - before[name + '_len'] = len_elt.text - - encs = ['aes', 'aes128gcm', 'aes192gcm', 'aes256gcm', 'blowfish', 'des', 'cast128'] + before[name + "_len"] = len_elt.text + + encs = [ + "aes", + "aes128gcm", + "aes192gcm", + "aes256gcm", + "blowfish", + "des", + "cast128", + ] for enc in encs: if enc not in before.keys(): before[enc] = False - if enc + '_len' not in before.keys(): - before[enc + '_len'] = None + if enc + "_len" not in before.keys(): + before[enc + "_len"] = None @staticmethod def _prepare_log_hashes(before, before_elt): - """ reparse some params for logging """ - hashes_elt = before_elt.findall('hash-algorithm-option') + """reparse some params for logging""" + hashes_elt = before_elt.findall("hash-algorithm-option") for hash_elt in hashes_elt: name = hash_elt.text.replace("hmac_", "") before[name] = True diff --git a/plugins/module_utils/ipsec_proposal.py b/plugins/module_utils/ipsec_proposal.py index 729dc811..e1dafa54 100644 --- a/plugins/module_utils/ipsec_proposal.py +++ b/plugins/module_utils/ipsec_proposal.py @@ -4,20 +4,67 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) from copy import deepcopy IPSEC_PROPOSAL_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - descr=dict(required=False, type='str'), - encryption=dict(required=True, choices=['aes', 'aes128gcm', 'aes192gcm', 'aes256gcm', 'blowfish', '3des', 'cast128'], type='str'), - key_length=dict(required=False, choices=[64, 96, 128, 192, 256], type='int'), - hash=dict(required=True, choices=['md5', 'sha1', 'sha256', 'sha384', 'sha512', 'aesxcbc'], type='str'), - prf=dict(required=False, choices=['md5', 'sha1', 'sha256', 'sha384', 'sha512', 'aesxcbc'], type='str'), - dhgroup=dict(required=True, choices=[1, 2, 5, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 28, 29, 30, 31, 32], type='int'), - apply=dict(default=True, type='bool'), + state=dict(default="present", choices=["present", "absent"]), + descr=dict(required=False, type="str"), + encryption=dict( + required=True, + choices=[ + "aes", + "aes128gcm", + "aes192gcm", + "aes256gcm", + "blowfish", + "3des", + "cast128", + ], + type="str", + ), + key_length=dict(required=False, choices=[64, 96, 128, 192, 256], type="int"), + hash=dict( + required=True, + choices=["md5", "sha1", "sha256", "sha384", "sha512", "aesxcbc"], + type="str", + ), + prf=dict( + required=False, + choices=["md5", "sha1", "sha256", "sha384", "sha512", "aesxcbc"], + type="str", + ), + dhgroup=dict( + required=True, + choices=[ + 1, + 2, + 5, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 28, + 29, + 30, + 31, + 32, + ], + type="int", + ), + apply=dict(default=True, type="bool"), ) IPSEC_PROPOSAL_REQUIRED_IF = [ @@ -30,11 +77,11 @@ class PFSenseIpsecProposalModule(PFSenseModuleBase): - """ module managing pfsense ipsec phase 1 proposals """ + """module managing pfsense ipsec phase 1 proposals""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return IPSEC_PROPOSAL_ARGUMENT_SPEC ############################## @@ -55,93 +102,112 @@ def __init__(self, module, pfsense=None): # def _onward_params(self): return [ - ['prf', self.pfsense.is_at_least_2_5_0], + ["prf", self.pfsense.is_at_least_2_5_0], ] def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() - obj['encryption-algorithm'] = dict() - obj['encryption-algorithm']['name'] = params['encryption'] - if params.get('key_length') is not None: - obj['encryption-algorithm']['keylen'] = str(params['key_length']) + obj["encryption-algorithm"] = dict() + obj["encryption-algorithm"]["name"] = params["encryption"] + if params.get("key_length") is not None: + obj["encryption-algorithm"]["keylen"] = str(params["key_length"]) else: - obj['encryption-algorithm']['keylen'] = '' - obj['hash-algorithm'] = params['hash'] - obj['dhgroup'] = str(params['dhgroup']) + obj["encryption-algorithm"]["keylen"] = "" + obj["hash-algorithm"] = params["hash"] + obj["dhgroup"] = str(params["dhgroup"]) if self.pfsense.is_at_least_2_5_0(): - if params.get('prf') is not None: - obj['prf-algorithm'] = params['prf'] + if params.get("prf") is not None: + obj["prf-algorithm"] = params["prf"] else: - obj['prf-algorithm'] = 'sha256' + obj["prf-algorithm"] = "sha256" - self.apply = params['apply'] + self.apply = params["apply"] return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params key_length = dict() - key_length['aes'] = ['128', '192', '256'] - key_length['aes192gcm'] = ['64', '96', '128'] - key_length['aes128gcm'] = ['64', '96', '128'] - key_length['aes256gcm'] = ['64', '96', '128'] - key_length['blowfish'] = ['128', '192', '256'] - if params['encryption'] in key_length.keys() and str(params['key_length']) not in key_length[params['encryption']]: - msg = 'key_length for encryption {0} must be one of: {1}.'.format(params['encryption'], ', '.join(key_length[params['encryption']])) + key_length["aes"] = ["128", "192", "256"] + key_length["aes192gcm"] = ["64", "96", "128"] + key_length["aes128gcm"] = ["64", "96", "128"] + key_length["aes256gcm"] = ["64", "96", "128"] + key_length["blowfish"] = ["128", "192", "256"] + if ( + params["encryption"] in key_length.keys() + and str(params["key_length"]) not in key_length[params["encryption"]] + ): + msg = "key_length for encryption {0} must be one of: {1}.".format( + params["encryption"], ", ".join(key_length[params["encryption"]]) + ) self.module.fail_json(msg=msg) # called from ipsec_aggregate - if params.get('ikeid') is not None: - self._phase1 = self.pfsense.find_ipsec_phase1(params['ikeid'], 'ikeid') + if params.get("ikeid") is not None: + self._phase1 = self.pfsense.find_ipsec_phase1(params["ikeid"], "ikeid") if self._phase1 is None: - self.module.fail_json(msg='No ipsec tunnel with ikeid {0}'.format(params['ikeid'])) + self.module.fail_json( + msg="No ipsec tunnel with ikeid {0}".format(params["ikeid"]) + ) else: - self._phase1 = self.pfsense.find_ipsec_phase1(params['descr']) + self._phase1 = self.pfsense.find_ipsec_phase1(params["descr"]) if self._phase1 is None: - self.module.fail_json(msg='No ipsec tunnel named {0}'.format(params['descr'])) + self.module.fail_json( + msg="No ipsec tunnel named {0}".format(params["descr"]) + ) - self.root_elt = self._phase1.find('encryption') + self.root_elt = self._phase1.find("encryption") if self.root_elt is None: - self.root_elt = self.pfsense.new_element('encryption') + self.root_elt = self.pfsense.new_element("encryption") self._phase1.append(self.root_elt) - if params['encryption'] in ['aes128gcm', 'aes192gcm', 'aes256gcm']: - iketype_elt = self._phase1.find('iketype') - if iketype_elt is not None and iketype_elt.text != 'ikev2': - self.module.fail_json(msg='Encryption Algorithm AES-GCM can only be used with IKEv2') + if params["encryption"] in ["aes128gcm", "aes192gcm", "aes256gcm"]: + iketype_elt = self._phase1.find("iketype") + if iketype_elt is not None and iketype_elt.text != "ikev2": + self.module.fail_json( + msg="Encryption Algorithm AES-GCM can only be used with IKEv2" + ) ############################## # XML processing # @staticmethod def _copy_and_update_target(): - """ update the XML target_elt """ + """update the XML target_elt""" return (None, False) def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('item') + """create the XML target_elt""" + return self.pfsense.new_element("item") def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" # 2.5.0: when deleting, if prf is not specified we're taking the first matching proposal without taking prf into account - if self.params['state'] == 'absent' and self.params.get('prf') is None and self.pfsense.is_at_least_2_5_0(): + if ( + self.params["state"] == "absent" + and self.params.get("prf") is None + and self.pfsense.is_at_least_2_5_0() + ): obj = deepcopy(self.obj) - obj.pop('prf-algorithm', None) + obj.pop("prf-algorithm", None) else: obj = self.obj - items_elt = self.root_elt.findall('item') + items_elt = self.root_elt.findall("item") for item in items_elt: existing = self.pfsense.element_to_dict(item) - if self.params['state'] == 'absent' and self.params.get('prf') is None and self.pfsense.is_at_least_2_5_0(): - existing.pop('prf-algorithm', None) + if ( + self.params["state"] == "absent" + and self.params.get("prf") is None + and self.pfsense.is_at_least_2_5_0() + ): + existing.pop("prf-algorithm", None) if existing == obj: return item return None @@ -150,27 +216,27 @@ def _find_target(self): # run # def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" return self.pfsense.apply_ipsec_changes() ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.params['descr']) + """return obj's name""" + return "'{0}'".format(self.params["descr"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' - values += self.format_cli_field(self.params, 'encryption') - values += self.format_cli_field(self.params, 'key_length') - values += self.format_cli_field(self.obj, 'hash-algorithm', fname='hash') - values += self.format_cli_field(self.obj, 'dhgroup') + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" + values += self.format_cli_field(self.params, "encryption") + values += self.format_cli_field(self.params, "key_length") + values += self.format_cli_field(self.obj, "hash-algorithm", fname="hash") + values += self.format_cli_field(self.obj, "dhgroup") if self.pfsense.is_at_least_2_5_0(): - values += self.format_cli_field(self.obj, 'prf-algorithm', fname='prf') + values += self.format_cli_field(self.obj, "prf-algorithm", fname="prf") return values def _log_fields_delete(self): - """ generate pseudo-CLI command fields parameters to delete an obj """ + """generate pseudo-CLI command fields parameters to delete an obj""" return self._log_fields() diff --git a/plugins/module_utils/module_base.py b/plugins/module_utils/module_base.py index 4ea4a0e2..9c08f24c 100644 --- a/plugins/module_utils/module_base.py +++ b/plugins/module_utils/module_base.py @@ -4,14 +4,21 @@ # Copyright: (c) 2024, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import PFSenseModule -from ansible_collections.pfsensible.core.plugins.module_utils.arg_route import p2o_interface +from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import ( + PFSenseModule, +) +from ansible_collections.pfsensible.core.plugins.module_utils.arg_route import ( + p2o_interface, +) BASE_ARG_ROUTE = dict( - interface=dict(parse=p2o_interface,), + interface=dict( + parse=p2o_interface, + ), ) @@ -33,7 +40,7 @@ def merge_dicts(a: dict, b: dict, path=None): # Move a key in dict to a new one, allowing the use '/' to specify nested dict location def move_dict_key(obj, src, dst): item = None - for n in reversed(dst.split('/')): + for n in reversed(dst.split("/")): if item is None: item = dict() item[n] = obj[src] @@ -46,7 +53,7 @@ def move_dict_key(obj, src, dst): class PFSenseModuleBase(object): - """ class providing base services for pfSense modules """ + """class providing base services for pfSense modules""" ############################## # unit tests @@ -54,51 +61,78 @@ class PFSenseModuleBase(object): # Must be class method for unit test usage @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" raise NotImplementedError() ############################## # init # - def __init__(self, module, pfsense=None, package=None, name=None, root=None, root_is_exclusive=True, create_root=False, node=None, key='descr', - update_php=None, arg_route=None, map_param=None, map_param_if=None, param_force=None, bool_style=None, bool_values=None, have_refid=False, - create_default=None): - self.module = module # ansible module - self.argument_spec = module.argument_spec # Allow for being overriden for use with aggregate + def __init__( + self, + module, + pfsense=None, + package=None, + name=None, + root=None, + root_is_exclusive=True, + create_root=False, + node=None, + key="descr", + update_php=None, + arg_route=None, + map_param=None, + map_param_if=None, + param_force=None, + bool_style=None, + bool_values=None, + have_refid=False, + create_default=None, + ): + self.module = module # ansible module + self.argument_spec = ( + module.argument_spec + ) # Allow for being overriden for use with aggregate # pfSense helper module if pfsense is None: pfsense = PFSenseModule(module) self.pfsense = pfsense - if name is not None: # ansible module name + if name is not None: # ansible module name self.name = name elif node is not None: - self.name = 'pfsense_' + node + self.name = "pfsense_" + node else: self.name = None - self.apply = True # apply configuration at the end + self.apply = True # apply configuration at the end # xml parent of target_elt, node named by root # TODO - handle paths with creation - e.g. if root is not None: - if root == 'pfsense': + if root == "pfsense": self.root_elt = self.pfsense.root self.root_is_exclusive = False else: if package is not None: - self.root_elt = self.pfsense.get_element(root, root_elt=self.pfsense.root.find('installedpackages')) + self.root_elt = self.pfsense.get_element( + root, root_elt=self.pfsense.root.find("installedpackages") + ) if self.root_elt is None: self.module.fail_json( - msg='Unable to find configuration for the package {package}. Are you sure that it is installed?'.format(package=package)) + msg="Unable to find configuration for the package {package}. Are you sure that it is installed?".format( + package=package + ) + ) else: root_elt = self.pfsense.root - for this in root.split('/'): - root_elt = self.pfsense.get_element(this, root_elt=root_elt, create_node=create_root) + for this in root.split("/"): + root_elt = self.pfsense.get_element( + this, root_elt=root_elt, create_node=create_root + ) self.root_elt = root_elt - if root in ['system']: + if root in ["system"]: self.root_is_exclusive = False else: self.root_is_exclusive = root_is_exclusive @@ -114,8 +148,8 @@ def __init__(self, module, pfsense=None, package=None, name=None, root=None, roo self.elements = None self.node = node - self.key = key # item that identifies a target element - self.obj = dict() # dict holding target pfsense parameters + self.key = key # item that identifies a target element + self.obj = dict() # dict holding target pfsense parameters # routing for argument handling self.arg_route = BASE_ARG_ROUTE @@ -138,31 +172,40 @@ def __init__(self, module, pfsense=None, package=None, name=None, root=None, roo self.param_force = param_force # parameters that are forced to be present else: self.param_force = list() - self.bool_style = bool_style # default boolean value style for arguments + self.bool_style = bool_style # default boolean value style for arguments if bool_values is not None: - self.bool_values = bool_values # boolean values for specific arguments + self.bool_values = bool_values # boolean values for specific arguments else: self.bool_values = dict() self.create_default = create_default # default values for a created target - self.have_refid = have_refid # if the element has a refid item + self.have_refid = have_refid # if the element has a refid item self.target_elt = None # xml object holding target pfsense parameters self.update_php = update_php # php code to update configuration - self.change_descr = '' + self.change_descr = "" self.result = {} - self.result['changed'] = False - self.result['commands'] = [] + self.result["changed"] = False + self.result["commands"] = [] - self.diff = {'after': {}, 'before': {}} - self.result['diff'] = self.diff + self.diff = {"after": {}, "before": {}} + self.result["diff"] = self.diff ############################## # params processing # - def _get_ansible_param(self, obj, name, fname=None, force=False, exclude=None, force_value='', params=None): - """ get parameter from params and set it into obj """ + def _get_ansible_param( + self, + obj, + name, + fname=None, + force=False, + exclude=None, + force_value="", + params=None, + ): + """get parameter from params and set it into obj""" if fname is None: fname = name if params is None: @@ -177,8 +220,17 @@ def _get_ansible_param(self, obj, name, fname=None, force=False, exclude=None, f elif force: obj[fname] = force_value - def _get_ansible_param_bool(self, obj, name, fname=None, force=False, value='yes', value_false=None, params=None): - """ get bool parameter from params and set it into obj """ + def _get_ansible_param_bool( + self, + obj, + name, + fname=None, + force=False, + value="yes", + value_false=None, + params=None, + ): + """get bool parameter from params and set it into obj""" if fname is None: fname = name if params is None: @@ -197,25 +249,31 @@ def _get_ansible_param_bool(self, obj, name, fname=None, force=False, value='yes obj[fname] = value_false def _params_to_obj(self, obj=None): - """ return a dict from module params that sets self.obj """ + """return a dict from module params that sets self.obj""" if obj is None: obj = dict() # Not all modules have 'state', treat them like they did - if self.params.get('state', 'present') == 'present': + if self.params.get("state", "present") == "present": # Skip 'state', but otherwise process all parameters. Ansible sets unspecified parameters to None. - for param in [p for p in self.params if p != 'state']: + for param in [p for p in self.params if p != "state"]: force = False if param in self.param_force: force = True # If we have defined a parser for this arg, use it - if param in self.arg_route and 'parse' in self.arg_route[param] and self.params.get(param) is not None: - self.arg_route[param]['parse'](self, param, self.params, obj) - elif self.argument_spec[param].get('type') == 'bool': + if param in self.arg_route and "parse" in self.arg_route[param] and self.params.get(param) is not None: + self.arg_route[param]["parse"](self, param, self.params, obj) + elif self.argument_spec[param].get("type") == "bool": if param in self.bool_values: - self._get_ansible_param_bool(obj, param, value=self.bool_values[param][1], value_false=self.bool_values[param][0], force=force) - elif self.bool_style == 'absent/present': - self._get_ansible_param_bool(obj, param, value='', force=force) + self._get_ansible_param_bool( + obj, + param, + value=self.bool_values[param][1], + value_false=self.bool_values[param][0], + force=force, + ) + elif self.bool_style == "absent/present": + self._get_ansible_param_bool(obj, param, value="", force=force) else: self._get_ansible_param_bool(obj, param, force=force) else: @@ -244,53 +302,61 @@ def _params_to_obj(self, obj=None): # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # Not all modules have 'state', treat them like they did - if self.params.get('state', 'present') == 'present': + if self.params.get("state", "present") == "present": # Ansible sets unspecied parameters to None, skip them for param in [p for p in self.params if self.params[p] is not None]: - if param in self.arg_route and 'validate' in self.arg_route[param]: + if param in self.arg_route and "validate" in self.arg_route[param]: try: - self.arg_route[param]['validate'](self, params[param]) + self.arg_route[param]["validate"](self, params[param]) except ValueError as e: self.module.fail_json(msg=str(e)) def _deprecated_params(self): - """ return deprecated params """ + """return deprecated params""" return None def _onward_params(self): - """ return onwards params """ + """return onwards params""" return None def _check_deprecated_params(self): - """ check if input parameters are deprecated """ + """check if input parameters are deprecated""" deprecated_params = self._deprecated_params() if deprecated_params is None: return for deprecated in deprecated_params: if self.params.get(deprecated[0]) is not None and deprecated[1](): - self.module.fail_json(msg='{0} is deprecated on pfSense {1}.'.format(deprecated[0], self.pfsense.get_version())) + self.module.fail_json( + msg="{0} is deprecated on pfSense {1}.".format( + deprecated[0], self.pfsense.get_version() + ) + ) def _check_onward_params(self): - """ check if input parameters are too recents """ + """check if input parameters are too recents""" onwards_params = self._onward_params() if onwards_params is None: return for onward in onwards_params: if self.params.get(onward[0]) is not None and not onward[1](): - self.module.fail_json(msg='{0} is not supported on pfSense {1}.'.format(onward[0], self.pfsense.get_version())) + self.module.fail_json( + msg="{0} is not supported on pfSense {1}.".format( + onward[0], self.pfsense.get_version() + ) + ) ############################## # XML processing # def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.obj + self.diff["after"] = self.obj if self.root_is_exclusive: self.root_elt.append(self.target_elt) else: @@ -299,24 +365,24 @@ def _copy_and_add_target(self): self.elements = self.root_elt.findall(self.node) def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) - self.diff['before'] = before + self.diff["before"] = before changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) if self._remove_deleted_params(): changed = True - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) return (before, changed) def _create_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" if self.node is not None: elt = self.pfsense.new_element(self.node) if self.have_refid: # Store in obj so that we can refer to it later if needed - self.obj['refid'] = self.pfsense.uniqid() - elt.append(self.pfsense.new_element('refid', text=self.obj['refid'])) + self.obj["refid"] = self.pfsense.uniqid() + elt.append(self.pfsense.new_element("refid", text=self.obj["refid"])) if self.create_default is not None: self.pfsense.copy_dict_to_element(self.create_default, elt) return elt @@ -333,115 +399,137 @@ def _find_last_element_index(self): return len(list(self.root_elt)) def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" if self.node is not None: - result = self.root_elt.findall("{node}[{key}='{value}']".format(node=self.node, key=self.key, value=self.obj[self.key])) + result = self.root_elt.findall( + "{node}[{key}='{value}']".format( + node=self.node, key=self.key, value=self.obj[self.key] + ) + ) if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple {node}s for {key} {value}.'.format(node=self.node, key=self.key, value=self.obj[self.key])) + self.module.fail_json( + msg="Found multiple {node}s for {key} {value}.".format( + node=self.node, key=self.key, value=self.obj[self.key] + ) + ) else: return None else: raise NotImplementedError() def _get_params_to_remove(self): - """ returns the list of params to remove if they are set to false """ + """returns the list of params to remove if they are set to false""" to_remove = [] # We need to remove any booleans set to false that are "None" when unset - for param in [n for n in self.argument_spec.keys() if self.argument_spec[n].get('type') == 'bool']: + for param in [ + n + for n in self.argument_spec.keys() + if self.argument_spec[n].get("type") == "bool" + ]: if self.params.get(param, None) is False: if param in self.bool_values and self.bool_values[param][0] is None: to_remove.append(param) - elif self.bool_style == 'absent/present': + elif self.bool_style == "absent/present": to_remove.append(param) return to_remove def _remove_deleted_params(self): - """ Remove from target_elt a few deleted params """ + """Remove from target_elt a few deleted params""" changed = False params = self._get_params_to_remove() for param in params: - if self.pfsense.remove_deleted_param_from_elt(self.target_elt, param, self.obj): + if self.pfsense.remove_deleted_param_from_elt( + self.target_elt, param, self.obj + ): changed = True return changed def _remove_target_elt(self): - """ delete target_elt from xml """ + """delete target_elt from xml""" self.root_elt.remove(self.target_elt) - self.result['changed'] = True + self.result["changed"] = True ############################## # run # def _add(self): - """ add or update obj """ + """add or update obj""" if self.target_elt is None: self.target_elt = self._create_target() self._copy_and_add_target() changed = True - self.change_descr = 'ansible {0} added {1}'.format(self._get_module_name(), self._get_obj_name()) + self.change_descr = "ansible {0} added {1}".format( + self._get_module_name(), self._get_obj_name() + ) self._log_create() else: (before, changed) = self._copy_and_update_target() if changed: - self.change_descr = 'ansible {0} updated {1}'.format(self._get_module_name(), self._get_obj_name()) + self.change_descr = "ansible {0} updated {1}".format( + self._get_module_name(), self._get_obj_name() + ) self._log_update(before) if changed: - self.result['changed'] = changed + self.result["changed"] = changed def commit_changes(self): - """ apply changes and exit module """ - self.result['stdout'] = '' - self.result['stderr'] = '' - if self.result['changed'] and not self.module.check_mode: + """apply changes and exit module""" + self.result["stdout"] = "" + self.result["stderr"] = "" + if self.result["changed"] and not self.module.check_mode: if self.apply: - (dummy, self.result['stdout'], self.result['stderr']) = self._pre_update() + (dummy, self.result["stdout"], self.result["stderr"]) = ( + self._pre_update() + ) self.pfsense.write_config(descr=self.change_descr) if self.apply: (dummy, stdout, stderr) = self._update() - self.result['stdout'] += stdout - self.result['stderr'] += stderr + self.result["stdout"] += stdout + self.result["stderr"] += stderr self.module.exit_json(**self.result) def _post_remove_target_elt(self): - """ processing after removing elt """ + """processing after removing elt""" pass def _pre_remove_target_elt(self): - """ processing before removing elt """ - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + """processing before removing elt""" + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) def _remove(self): - """ delete obj """ + """delete obj""" if self.target_elt is not None: self._pre_remove_target_elt() self._log_delete() self._remove_target_elt() self._post_remove_target_elt() - self.change_descr = 'ansible {0} removed {1}'.format(self._get_module_name(), self._get_obj_name()) + self.change_descr = "ansible {0} removed {1}".format( + self._get_module_name(), self._get_obj_name() + ) @staticmethod def _pre_update(): - """ tasks to run before making config changes """ - return ('', '', '') + """tasks to run before making config changes""" + return ("", "", "") def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" if self.update_php is not None: return self.pfsense.phpshell(self.update_php) else: - return ('', '', '') + return ("", "", "") # We take params here for use with pfsense_aggregate and the test framework def run(self, params): - """ process input params to add/update/delete """ + """process input params to add/update/delete""" self.params = params self.target_elt = None self._check_deprecated_params() @@ -452,7 +540,7 @@ def run(self, params): if self.target_elt is None: self.target_elt = self._find_target() - if params.get('state', None) == 'absent': + if params.get("state", None) == "absent": self._remove() else: self._add() @@ -461,54 +549,77 @@ def run(self, params): # Logging # def _log_create(self): - """ generate pseudo-CLI command to create an obj """ + """generate pseudo-CLI command to create an obj""" log = "create {0} {1}".format(self._get_module_name(True), self._get_obj_name()) log += self._log_fields() - self.result['commands'].append(log) + self.result["commands"].append(log) def _log_delete(self): - """ generate pseudo-CLI command to delete an obj """ + """generate pseudo-CLI command to delete an obj""" log = "delete {0} {1}".format(self._get_module_name(True), self._get_obj_name()) log += self._log_fields_delete() - self.result['commands'].append(log) + self.result["commands"].append(log) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - for param in [n for n in self.argument_spec.keys() if n != 'state' and n != self.key]: + for param in [ + n for n in self.argument_spec.keys() if n != "state" and n != self.key + ]: values += self.format_cli_field(self.obj, param) else: - for param in [n for n in self.argument_spec.keys() if n != 'state' and n != self.key]: - if self.argument_spec[param].get('type') == 'bool': - values += self.format_updated_cli_field(self.diff['after'], before, param, fvalue=self.fvalue_bool, add_comma=(values)) + for param in [ + n for n in self.argument_spec.keys() if n != "state" and n != self.key + ]: + if self.argument_spec[param].get("type") == "bool": + values += self.format_updated_cli_field( + self.diff["after"], + before, + param, + fvalue=self.fvalue_bool, + add_comma=(values), + ) else: - values += self.format_updated_cli_field(self.diff['after'], before, param, add_comma=(values)) + values += self.format_updated_cli_field( + self.diff["after"], before, param, add_comma=(values) + ) return values @staticmethod def _log_fields_delete(): - """ generate pseudo-CLI command fields parameters to delete an obj """ + """generate pseudo-CLI command fields parameters to delete an obj""" return "" def _log_update(self, before): - """ generate pseudo-CLI command to update an obj """ + """generate pseudo-CLI command to update an obj""" log = "update {0} {1}".format(self._get_module_name(True), self._get_obj_name()) values = self._log_fields(before) - self.result['commands'].append(log + ' set ' + values) + self.result["commands"].append(log + " set " + values) def _get_obj_name(self): - """ return obj's name """ + """return obj's name""" return "'{0}'".format(self.obj[self.key]) def _get_module_name(self, strip=False): - """ return ansible module's name """ + """return ansible module's name""" if strip: return self.name.replace("pfsense_", "") return self.name - def format_cli_field(self, after, field, log_none=False, add_comma=True, fvalue=None, default=None, fname=None, none_value=None, force=False): - """ format field for pseudo-CLI command """ + def format_cli_field( + self, + after, + field, + log_none=False, + add_comma=True, + fvalue=None, + default=None, + fname=None, + none_value=None, + force=False, + ): + """format field for pseudo-CLI command""" if fvalue is None: fvalue = self.fvalue_idem @@ -516,30 +627,43 @@ def format_cli_field(self, after, field, log_none=False, add_comma=True, fvalue= fname = field if none_value is None: - none_value = 'none' + none_value = "none" - res = '' + res = "" if field in after: if log_none and after[field] is None: res = "{0}={1}".format(fname, fvalue(none_value)) if after[field] is not None: if default is None or after[field] != default: if isinstance(after[field], str) and fvalue != self.fvalue_bool: - res = "{0}='{1}'".format(fname, fvalue(after[field].replace("'", "\\'"))) + res = "{0}='{1}'".format( + fname, fvalue(after[field].replace("'", "\\'")) + ) else: res = "{0}={1}".format(fname, fvalue(after[field])) elif log_none or force: res = "{0}={1}".format(fname, fvalue(none_value)) if add_comma and res: - return ', ' + res + return ", " + res return res - def format_updated_cli_field(self, after, before, field, log_none=True, add_comma=True, fvalue=None, default=None, fname=None, none_value=None): - """ format field for pseudo-CLI update command """ + def format_updated_cli_field( + self, + after, + before, + field, + log_none=True, + add_comma=True, + fvalue=None, + default=None, + fname=None, + none_value=None, + ): + """format field for pseudo-CLI update command""" log = False if none_value is None: - none_value = 'none' + none_value = "none" if field in after and field in before: if fvalue is None and after[field] != before[field]: @@ -547,28 +671,49 @@ def format_updated_cli_field(self, after, before, field, log_none=True, add_comm elif fvalue is not None and fvalue(after[field]) != fvalue(before[field]): log = True elif fvalue is None: - if field in after and field not in before or field not in after and field in before: + if ( + field in after + and field not in before + or field not in after + and field in before + ): log = True - elif field in after and field not in before and fvalue(after[field]) != fvalue(none_value): + elif ( + field in after + and field not in before + and fvalue(after[field]) != fvalue(none_value) + ): log = True - elif field not in after and field in before and fvalue(before[field]) != fvalue(none_value): + elif ( + field not in after + and field in before + and fvalue(before[field]) != fvalue(none_value) + ): log = True if log: return self.format_cli_field( - after, field, log_none=log_none, add_comma=add_comma, fvalue=fvalue, default=default, fname=fname, none_value=none_value, force=True + after, + field, + log_none=log_none, + add_comma=add_comma, + fvalue=fvalue, + default=default, + fname=fname, + none_value=none_value, + force=True, ) - return '' + return "" @staticmethod def fvalue_idem(value): - """ dummy value formatting function """ + """dummy value formatting function""" return value @staticmethod def fvalue_bool(value): - """ boolean value formatting function """ - if value is None or value is False or value == 'none': - return 'False' + """boolean value formatting function""" + if value is None or value is False or value == "none": + return "False" - return 'True' + return "True" diff --git a/plugins/module_utils/module_config_base.py b/plugins/module_utils/module_config_base.py index 354ee6cb..ed5afb1a 100644 --- a/plugins/module_utils/module_config_base.py +++ b/plugins/module_utils/module_config_base.py @@ -4,31 +4,67 @@ # Copyright: (c) 2024, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase, merge_dicts +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, + merge_dicts, +) class PFSenseModuleConfigBase(PFSenseModuleBase): - """ class for implementing pfSense modules that manage a set of configuration settings """ + """class for implementing pfSense modules that manage a set of configuration settings""" ############################## # init # - def __init__(self, module, pfsense=None, package=None, name=None, root=None, root_is_exclusive=True, create_root=False, node=None, key='descr', - update_php=None, arg_route=None, map_param=None, map_param_if=None, param_force=None, bool_style=None, bool_values=None, have_refid=False, - create_default=None): - super(PFSenseModuleConfigBase, self).__init__(module, pfsense=pfsense, package=package, name=name, root=root, root_is_exclusive=True, create_root=False, - update_php=update_php, arg_route=arg_route, map_param=map_param, map_param_if=map_param_if, - param_force=param_force, bool_style=bool_style, bool_values=bool_values, create_default=create_default) + def __init__( + self, + module, + pfsense=None, + package=None, + name=None, + root=None, + root_is_exclusive=True, + create_root=False, + node=None, + key="descr", + update_php=None, + arg_route=None, + map_param=None, + map_param_if=None, + param_force=None, + bool_style=None, + bool_values=None, + have_refid=False, + create_default=None, + ): + super(PFSenseModuleConfigBase, self).__init__( + module, + pfsense=pfsense, + package=package, + name=name, + root=root, + root_is_exclusive=True, + create_root=False, + update_php=update_php, + arg_route=arg_route, + map_param=map_param, + map_param_if=map_param_if, + param_force=param_force, + bool_style=bool_style, + bool_values=bool_values, + create_default=create_default, + ) ############################## # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" obj = self.pfsense.element_to_dict(self.root_elt) merge_dicts(obj, super(PFSenseModuleConfigBase, self)._params_to_obj(obj=obj)) return obj @@ -37,12 +73,12 @@ def _params_to_obj(self): # XML processing # def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" return self.root_elt ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return re.sub(r'pfsense_', '', self.name) + """return obj's name""" + return re.sub(r"pfsense_", "", self.name) diff --git a/plugins/module_utils/nat_outbound.py b/plugins/module_utils/nat_outbound.py index d70f9d1a..9fa6dd7a 100644 --- a/plugins/module_utils/nat_outbound.py +++ b/plugins/module_utils/nat_outbound.py @@ -4,38 +4,70 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) from string import hexdigits from hashlib import md5 import re import sys NAT_OUTBOUND_ARGUMENT_SPEC = dict( - descr=dict(required=True, type='str'), - state=dict(default='present', choices=['present', 'absent']), - disabled=dict(default=False, required=False, type='bool'), - nonat=dict(default=False, required=False, type='bool'), - interface=dict(required=False, type='str'), - ipprotocol=dict(required=False, default='inet46', choices=['inet', 'inet46', 'inet6']), - protocol=dict(default='any', required=False, choices=["any", "tcp", "udp", "tcp/udp", "icmp", "esp", "ah", "gre", "ipv6", "igmp", "carp", "pfsync"]), - source=dict(required=False, type='str'), - destination=dict(required=False, type='str'), - invert=dict(default=False, required=False, type='bool'), - address=dict(required=False, type='str'), + descr=dict(required=True, type="str"), + state=dict(default="present", choices=["present", "absent"]), + disabled=dict(default=False, required=False, type="bool"), + nonat=dict(default=False, required=False, type="bool"), + interface=dict(required=False, type="str"), + ipprotocol=dict( + required=False, default="inet46", choices=["inet", "inet46", "inet6"] + ), + protocol=dict( + default="any", + required=False, + choices=[ + "any", + "tcp", + "udp", + "tcp/udp", + "icmp", + "esp", + "ah", + "gre", + "ipv6", + "igmp", + "carp", + "pfsync", + ], + ), + source=dict(required=False, type="str"), + destination=dict(required=False, type="str"), + invert=dict(default=False, required=False, type="bool"), + address=dict(required=False, type="str"), poolopts=dict( - default='', required=False, choices=["", "round-robin", "round-robin sticky-address", "random", "random sticky-address", "source-hash", "bitmask"] + default="", + required=False, + choices=[ + "", + "round-robin", + "round-robin sticky-address", + "random", + "random sticky-address", + "source-hash", + "bitmask", + ], ), - source_hash_key=dict(default='', type='str', no_log=True), - staticnatport=dict(default=False, required=False, type='bool'), - nosync=dict(default=False, required=False, type='bool'), - after=dict(required=False, type='str'), - before=dict(required=False, type='str'), + source_hash_key=dict(default="", type="str", no_log=True), + staticnatport=dict(default=False, required=False, type="bool"), + nosync=dict(default=False, required=False, type="bool"), + after=dict(required=False, type="str"), + before=dict(required=False, type="str"), ) NAT_OUTBOUND_MUTUALLY_EXCLUSIVE = [ - ('after', 'before'), + ("after", "before"), ] NAT_OUTBOUND_REQUIRED_IF = [ @@ -44,10 +76,10 @@ # Booleans that map to different values NAT_OUTBOUND_BOOL_VALUES = dict( - disabled=(None, ''), - staticnatport=(None, ''), - nonat=(None, ''), - nosync=(None, ''), + disabled=(None, ""), + staticnatport=(None, ""), + nonat=(None, ""), + nosync=(None, ""), ) @@ -61,13 +93,13 @@ def p2o_before(self, name, params, obj): def p2o_ipprotocol(self, name, params, obj): # IPv4+6 is marked by the absense of an ipprotocol element - if params[name] != 'inet46': + if params[name] != "inet46": self.obj[name] = params[name] def p2o_protocol(self, name, params, obj): # 'any' is marked by the absense of a protocol element - if params[name] != 'any': + if params[name] != "any": self.obj[name] = params[name] @@ -80,19 +112,25 @@ def p2o_protocol(self, name, params, obj): class PFSenseNatOutboundModule(PFSenseModuleBase): - """ module managing pfsense NAT rules """ + """module managing pfsense NAT rules""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return NAT_OUTBOUND_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseNatOutboundModule, self).__init__(module, pfsense, root='nat/outbound', create_root=True, arg_route=NAT_OUTBOUND_ARG_ROUTE, - bool_values=NAT_OUTBOUND_BOOL_VALUES) + super(PFSenseNatOutboundModule, self).__init__( + module, + pfsense, + root="nat/outbound", + create_root=True, + arg_route=NAT_OUTBOUND_ARG_ROUTE, + bool_values=NAT_OUTBOUND_BOOL_VALUES, + ) self.name = "pfsense_nat_outbound" # Override for use with aggregate self.argument_spec = NAT_OUTBOUND_ARGUMENT_SPEC @@ -105,72 +143,96 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" obj = super(PFSenseNatOutboundModule, self)._params_to_obj() params = self.params - if params['state'] == 'present': - self._parse_address(obj, 'source', 'sourceport', True, 'network') - self._parse_address(obj, 'destination', 'dstport', False, 'network') - if params['invert']: - obj['destination']['not'] = None + if params["state"] == "present": + self._parse_address(obj, "source", "sourceport", True, "network") + self._parse_address(obj, "destination", "dstport", False, "network") + if params["invert"]: + obj["destination"]["not"] = None self._parse_translated_address(obj) - if obj['source_hash_key'] != '' and not obj['source_hash_key'].startswith('0x'): + if obj["source_hash_key"] != "" and not obj["source_hash_key"].startswith( + "0x" + ): if sys.version_info[0] >= 3: - obj['source_hash_key'] = '0x' + md5(obj['source_hash_key'].encode('utf-8')).hexdigest() + obj["source_hash_key"] = ( + "0x" + md5(obj["source_hash_key"].encode("utf-8")).hexdigest() + ) else: - obj['source_hash_key'] = '0x' + md5(obj['source_hash_key']).hexdigest() + obj["source_hash_key"] = ( + "0x" + md5(obj["source_hash_key"]).hexdigest() + ) return obj def _parse_address(self, obj, field, field_port, allow_self, target): - """ validate param address field and returns it as a dict """ - if self.params.get(field) is None or self.params[field] == '': + """validate param address field and returns it as a dict""" + if self.params.get(field) is None or self.params[field] == "": return param = self.params[field] - addr = param.split(':') + addr = param.split(":") if len(addr) > 3: - self.module.fail_json(msg='Cannot parse address %s' % (param)) + self.module.fail_json(msg="Cannot parse address %s" % (param)) address = addr[0] ret = dict() - if address == 'NET': + if address == "NET": interface = addr[1] if len(addr) > 1 else None ports = addr[2] if len(addr) > 2 else None - if interface is None or interface == '': - self.module.fail_json(msg='Cannot parse address %s' % (param)) + if interface is None or interface == "": + self.module.fail_json(msg="Cannot parse address %s" % (param)) - ret['network'] = self.pfsense.parse_interface(interface) + ret["network"] = self.pfsense.parse_interface(interface) else: ports = addr[1] if len(addr) > 1 else None - if address == 'any': - if field == 'source': - ret[target] = 'any' + if address == "any": + if field == "source": + ret[target] = "any" else: - ret['any'] = '' + ret["any"] = "" # rule with this firewall - elif allow_self and address == '(self)': - ret[target] = '(self)' - elif self.params['ipprotocol'] != 'inet6' and self.pfsense.is_ipv4_address(address): - ret[target] = address + '/32' - self.module.warn('Specifying an address without a CIDR prefix is depracated. Please add /32 if you want a single host address') - elif self.params['ipprotocol'] != 'inet4' and self.pfsense.is_ipv6_address(address): - ret[target] = address + '/128' - self.module.warn('Specifying an address without a CIDR prefix is depracated. Please add /128 if you want a single host address') - elif self.params['ipprotocol'] != 'inet6' and self.pfsense.is_ipv4_network(address, False): + elif allow_self and address == "(self)": + ret[target] = "(self)" + elif self.params["ipprotocol"] != "inet6" and self.pfsense.is_ipv4_address( + address + ): + ret[target] = address + "/32" + self.module.warn( + "Specifying an address without a CIDR prefix is depracated. Please add /32 if you want a single host address" + ) + elif self.params["ipprotocol"] != "inet4" and self.pfsense.is_ipv6_address( + address + ): + ret[target] = address + "/128" + self.module.warn( + "Specifying an address without a CIDR prefix is depracated. Please add /128 if you want a single host address" + ) + elif self.params["ipprotocol"] != "inet6" and self.pfsense.is_ipv4_network( + address, False + ): (addr, bits) = self.pfsense.parse_ip_network(address, False, False) - ret[target] = addr + '/' + str(bits) - elif self.params['ipprotocol'] != 'inet4' and self.pfsense.is_ipv6_network(address, False): + ret[target] = addr + "/" + str(bits) + elif self.params["ipprotocol"] != "inet4" and self.pfsense.is_ipv6_network( + address, False + ): (addr, bits) = self.pfsense.parse_ip_network(address, False, False) - ret[target] = addr + '/' + str(bits) - elif self.pfsense.find_alias(address, 'host') is not None or self.pfsense.find_alias(address, 'network') is not None: + ret[target] = addr + "/" + str(bits) + elif ( + self.pfsense.find_alias(address, "host") is not None + or self.pfsense.find_alias(address, "network") is not None + ): ret[target] = address else: - self.module.fail_json(msg='Cannot parse address %s, not %s network or alias' % (address, self.params['ipprotocol'])) + self.module.fail_json( + msg="Cannot parse address %s, not %s network or alias" + % (address, self.params["ipprotocol"]) + ) if ports is not None: self._parse_ports(obj, ports, field_port, param) @@ -178,89 +240,113 @@ def _parse_address(self, obj, field, field_port, allow_self, target): obj[field] = ret def _parse_ports(self, obj, ports, field_port, param): - """ validate param address field and returns it as a dict """ + """validate param address field and returns it as a dict""" if ports is not None: - ports = ports.split('-') - if len(ports) > 2 or ports[0] is None or ports[0] == '' or len(ports) == 2 and (ports[1] is None or ports[1] == ''): - self.module.fail_json(msg='Cannot parse address %s' % (param)) + ports = ports.split("-") + if ( + len(ports) > 2 + or ports[0] is None + or ports[0] == "" + or len(ports) == 2 + and (ports[1] is None or ports[1] == "") + ): + self.module.fail_json(msg="Cannot parse address %s" % (param)) if not self.pfsense.is_port_or_alias(ports[0]): - self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[0])) + self.module.fail_json( + msg="Cannot parse port %s, not port number or alias" % (ports[0]) + ) obj[field_port] = ports[0] if len(ports) > 1: if not self.pfsense.is_port_or_alias(ports[1]): - self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[1])) - obj[field_port] += ':' + ports[1] + self.module.fail_json( + msg="Cannot parse port %s, not port number or alias" + % (ports[1]) + ) + obj[field_port] += ":" + ports[1] def _parse_translated_address(self, obj): - """ validate param address field and returns it as a dict """ - obj['target'] = '' - obj['target_subnet'] = '' + """validate param address field and returns it as a dict""" + obj["target"] = "" + obj["target_subnet"] = "" - if self.params.get('address') is None or self.params['address'] == '': + if self.params.get("address") is None or self.params["address"] == "": return - param = self.params['address'] - addr = param.split(':') + param = self.params["address"] + addr = param.split(":") if len(addr) > 2: - self.module.fail_json(msg='Cannot parse address %s' % (param)) + self.module.fail_json(msg="Cannot parse address %s" % (param)) address = addr[0] ports = addr[1] if len(addr) > 1 else None - if address is not None and address != '': + if address is not None and address != "": if self.pfsense.is_virtual_ip(address): - obj['target'] = address - obj['target_subnet'] = None - elif self.pfsense.find_alias(address, 'host') is not None or self.pfsense.find_alias(address, 'network') is not None: - obj['target'] = address - if obj['poolopts'] != '' and not obj['poolopts'].startswith('round-robin'): - self.module.fail_json(msg='Only Round Robin pool options may be chosen when selecting an alias.') - obj['target_subnet'] = '32' + obj["target"] = address + obj["target_subnet"] = None + elif ( + self.pfsense.find_alias(address, "host") is not None + or self.pfsense.find_alias(address, "network") is not None + ): + obj["target"] = address + if obj["poolopts"] != "" and not obj["poolopts"].startswith( + "round-robin" + ): + self.module.fail_json( + msg="Only Round Robin pool options may be chosen when selecting an alias." + ) + obj["target_subnet"] = "32" elif self.pfsense.is_ipv4_address(address): - obj['target'] = address - obj['target_subnet'] = '32' + obj["target"] = address + obj["target_subnet"] = "32" else: (addr, part) = self.pfsense.parse_ip_network(address, False, False) if addr is None: - self.module.fail_json(msg='Cannot parse address %s, not IP or alias' % (address)) - obj['target'] = addr - obj['target_subnet'] = str(part) - del obj['address'] + self.module.fail_json( + msg="Cannot parse address %s, not IP or alias" % (address) + ) + obj["target"] = addr + obj["target_subnet"] = str(part) + del obj["address"] - self._parse_ports(obj, ports, 'natport', param) + self._parse_ports(obj, ports, "natport", param) def _validate_params(self): - """ do some extra checks on input parameters """ - - if self.params.get('after'): - if self.params['after'] == self.params['descr']: - self.module.fail_json(msg='Cannot specify the current rule in after') - elif self.params.get('before'): - if self.params['before'] == self.params['descr']: - self.module.fail_json(msg='Cannot specify the current rule in before') - - if self.params.get('source_hash_key') is not None and self.params['source_hash_key'].startswith('0x'): - hash = self.params['source_hash_key'][2:] + """do some extra checks on input parameters""" + + if self.params.get("after"): + if self.params["after"] == self.params["descr"]: + self.module.fail_json(msg="Cannot specify the current rule in after") + elif self.params.get("before"): + if self.params["before"] == self.params["descr"]: + self.module.fail_json(msg="Cannot specify the current rule in before") + + if self.params.get("source_hash_key") is not None and self.params[ + "source_hash_key" + ].startswith("0x"): + hash = self.params["source_hash_key"][2:] if len(hash) != 32 or not all(c in hexdigits for c in hash): - self.module.fail_json(msg='Incorrect format for source-hash key, \"0x\" must be followed by exactly 32 hexadecimal characters.') + self.module.fail_json( + msg='Incorrect format for source-hash key, "0x" must be followed by exactly 32 hexadecimal characters.' + ) ############################## # XML processing # def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.obj + self.diff["after"] = self.obj self._insert(self.target_elt) def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) - self.diff['before'] = before + self.diff["before"] = before changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) if self._remove_deleted_params(): changed = True @@ -270,44 +356,44 @@ def _copy_and_update_target(self): return (before, changed) def _create_target(self): - """ create the XML target_elt """ - target_elt = self.pfsense.new_element('rule') + """create the XML target_elt""" + target_elt = self.pfsense.new_element("rule") return target_elt def _find_first_rule_idx(self): - """ find the XML first rule idx """ + """find the XML first rule idx""" for idx, rule_elt in enumerate(self.root_elt): - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue return idx return len(self.root_elt) def _find_rule_by_descr(self, descr): - """ find the XML target_elt """ + """find the XML target_elt""" for idx, rule_elt in enumerate(self.root_elt): - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue - if rule_elt.find('descr').text == descr: + if rule_elt.find("descr").text == descr: return (rule_elt, idx) return (None, None) def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" for rule_elt in self.root_elt: - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue - if rule_elt.find('descr').text == self.obj['descr']: + if rule_elt.find("descr").text == self.obj["descr"]: return rule_elt return None def _get_expected_rule_position(self): - """ get expected rule position in interface/floating """ - if self.before == 'bottom': + """get expected rule position in interface/floating""" + if self.before == "bottom": return len(self.root_elt) - elif self.after == 'top': + elif self.after == "top": return self._find_first_rule_idx() elif self.after is not None: return self._get_rule_position(self.after) + 1 @@ -324,25 +410,29 @@ def _get_expected_rule_position(self): return -1 def _get_expected_rule_xml_index(self): - """ get expected rule index in xml """ - if self.before == 'bottom': + """get expected rule index in xml""" + if self.before == "bottom": return len(self.root_elt) - elif self.after == 'top': + elif self.after == "top": return self._find_first_rule_idx() elif self.after is not None: found, i = self._find_rule_by_descr(self.after) if found is not None: return i + 1 else: - self.module.fail_json(msg='Failed to insert after rule=%s' % (self.after)) + self.module.fail_json( + msg="Failed to insert after rule=%s" % (self.after) + ) elif self.before is not None: found, i = self._find_rule_by_descr(self.before) if found is not None: return i else: - self.module.fail_json(msg='Failed to insert before rule=%s' % (self.before)) + self.module.fail_json( + msg="Failed to insert before rule=%s" % (self.before) + ) else: - found, i = self._find_rule_by_descr(self.obj['descr']) + found, i = self._find_rule_by_descr(self.obj["descr"]) if found is not None: return i return len(self.root_elt) @@ -350,26 +440,36 @@ def _get_expected_rule_xml_index(self): @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - return ['disabled', 'nonat', 'invert', 'staticnatport', 'nosync', 'dstport', 'natport', 'ipprotocol', 'protocol'] + """returns the list of params to remove if they are not set""" + return [ + "disabled", + "nonat", + "invert", + "staticnatport", + "nosync", + "dstport", + "natport", + "ipprotocol", + "protocol", + ] def _get_rule_position(self, descr=None, fail=True): - """ get rule position in interface/floating """ + """get rule position in interface/floating""" if descr is None: - descr = self.obj['descr'] + descr = self.obj["descr"] (res, idx) = self._find_rule_by_descr(descr) if fail and res is None: - self.module.fail_json(msg='Failed to find rule=%s' % (descr)) + self.module.fail_json(msg="Failed to find rule=%s" % (descr)) return idx def _insert(self, rule_elt): - """ insert rule into xml """ + """insert rule into xml""" rule_xml_idx = self._get_expected_rule_xml_index() self.root_elt.insert(rule_xml_idx, rule_elt) def _update_rule_position(self, rule_elt): - """ move rule in xml if required """ + """move rule in xml if required""" current_position = self._get_rule_position() expected_position = self._get_expected_rule_position() if current_position == expected_position: @@ -385,103 +485,189 @@ def _update_rule_position(self, rule_elt): # run # def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell('''require_once("filter.inc"); -if (filter_configure() == 0) { clear_subsystem_dirty('natconf'); clear_subsystem_dirty('filter'); }''') + """make the target pfsense reload""" + return self.pfsense.phpshell( + """require_once("filter.inc"); +if (filter_configure() == 0) { clear_subsystem_dirty('natconf'); clear_subsystem_dirty('filter'); }""" + ) ############################## # Logging # @staticmethod def fvalue_protocol(value): - """ boolean value formatting function """ - if value is None or value == 'none': - return 'any' + """boolean value formatting function""" + if value is None or value == "none": + return "any" return value @staticmethod def fvalue_ipprotocol(value): - """ boolean value formatting function """ - if value is None or value == 'none': - return 'inet46' + """boolean value formatting function""" + if value is None or value == "none": + return "inet46" return value def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" fafter = self._obj_to_log_fields(self.obj) if before is None: - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'nonat', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'interface') - values += self.format_cli_field(self.params, 'ipprotocol', fvalue=self.fvalue_ipprotocol, default='inet46') - values += self.format_cli_field(self.params, 'protocol', fvalue=self.fvalue_protocol, default='any') - values += self.format_cli_field(self.params, 'source') - values += self.format_cli_field(self.params, 'destination') - values += self.format_cli_field(self.params, 'invert', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(fafter, 'address', default='') - values += self.format_cli_field(self.params, 'poolopts', default='') - values += self.format_cli_field(self.obj, 'source_hash_key', default='') - values += self.format_cli_field(self.params, 'staticnatport', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'nosync', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'after') - values += self.format_cli_field(self.params, 'before') + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field( + self.params, "nonat", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.params, "interface") + values += self.format_cli_field( + self.params, + "ipprotocol", + fvalue=self.fvalue_ipprotocol, + default="inet46", + ) + values += self.format_cli_field( + self.params, "protocol", fvalue=self.fvalue_protocol, default="any" + ) + values += self.format_cli_field(self.params, "source") + values += self.format_cli_field(self.params, "destination") + values += self.format_cli_field( + self.params, "invert", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(fafter, "address", default="") + values += self.format_cli_field(self.params, "poolopts", default="") + values += self.format_cli_field(self.obj, "source_hash_key", default="") + values += self.format_cli_field( + self.params, "staticnatport", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field( + self.params, "nosync", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.params, "after") + values += self.format_cli_field(self.params, "before") else: fbefore = self._obj_to_log_fields(before) - fafter['before'] = self.before - fafter['after'] = self.after - - values += self.format_updated_cli_field(self.obj, before, 'disabled', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'nonat', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'interface', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'ipprotocol', fvalue=self.fvalue_ipprotocol, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'protocol', fvalue=self.fvalue_protocol, add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'source', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'destination', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'invert', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(fafter, before, 'address', default='', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'poolopts', default='', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'source_hash_key', default='', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'staticnatport', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'nosync', fvalue=self.fvalue_bool, default=False, add_comma=(values)) + fafter["before"] = self.before + fafter["after"] = self.after + + values += self.format_updated_cli_field( + self.obj, + before, + "disabled", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "nonat", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + fafter, fbefore, "interface", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "ipprotocol", + fvalue=self.fvalue_ipprotocol, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "protocol", + fvalue=self.fvalue_protocol, + add_comma=(values), + ) + values += self.format_updated_cli_field( + fafter, fbefore, "source", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "destination", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "invert", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + fafter, before, "address", default="", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "poolopts", default="", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "source_hash_key", default="", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "staticnatport", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "nosync", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) if self.position_changed: - values += self.format_updated_cli_field(fafter, {}, 'after', log_none=False, add_comma=(values)) - values += self.format_updated_cli_field(fafter, {}, 'before', log_none=False, add_comma=(values)) + values += self.format_updated_cli_field( + fafter, {}, "after", log_none=False, add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, {}, "before", log_none=False, add_comma=(values) + ) return values def _obj_address_to_log_field(self, rule, addr, target, port): - """ return formated address from dict """ - field = '' + """return formated address from dict""" + field = "" if addr in rule: if target in rule[addr]: if self.pfsense.interfaces.find(rule[addr][target]): - field = 'NET:' + field = "NET:" field += rule[addr][target] - elif addr == 'destination' and 'any' in rule[addr]: - field = 'any' + elif addr == "destination" and "any" in rule[addr]: + field = "any" - if port in rule and rule[port] is not None and rule[port] != '': - field += ':' - field += rule[port].replace(':', '-') + if port in rule and rule[port] is not None and rule[port] != "": + field += ":" + field += rule[port].replace(":", "-") return field def _obj_to_log_fields(self, rule): - """ return formated source and destination from dict """ + """return formated source and destination from dict""" res = {} - res['source'] = self._obj_address_to_log_field(rule, 'source', 'network', 'sourceport') - res['destination'] = self._obj_address_to_log_field(rule, 'destination', 'network', 'dstport') - res['interface'] = self.pfsense.get_interface_display_name(rule['interface']) - - if rule.get('target', '') != '': - if re.match(r'[a-zA-Z]', rule['target']): - res['address'] = rule['target'] + res["source"] = self._obj_address_to_log_field( + rule, "source", "network", "sourceport" + ) + res["destination"] = self._obj_address_to_log_field( + rule, "destination", "network", "dstport" + ) + res["interface"] = self.pfsense.get_interface_display_name(rule["interface"]) + + if rule.get("target", "") != "": + if re.match(r"[a-zA-Z]", rule["target"]): + res["address"] = rule["target"] else: - res['address'] = rule['target'] + '/' + rule['target_subnet'] - if rule.get('natport', '') != '': - res['address'] += ':' - res['address'] += rule['natport'].replace(':', '-') + res["address"] = rule["target"] + "/" + rule["target_subnet"] + if rule.get("natport", "") != "": + res["address"] += ":" + res["address"] += rule["natport"].replace(":", "-") return res diff --git a/plugins/module_utils/nat_port_forward.py b/plugins/module_utils/nat_port_forward.py index d0899289..d7a8b568 100644 --- a/plugins/module_utils/nat_port_forward.py +++ b/plugins/module_utils/nat_port_forward.py @@ -5,27 +5,55 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase -from ansible_collections.pfsensible.core.plugins.module_utils.rule import PFSenseRuleModule +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule import ( + PFSenseRuleModule, +) NAT_PORT_FORWARD_ARGUMENT_SPEC = dict( - descr=dict(required=True, type='str'), - state=dict(default='present', choices=['present', 'absent']), - disabled=dict(default=False, required=False, type='bool'), - nordr=dict(default=False, required=False, type='bool'), - interface=dict(required=False, type='str'), - ipprotocol=dict(default='inet', choices=['inet', 'inet6']), - protocol=dict(default='tcp', required=False, choices=["tcp", "udp", "tcp/udp", "icmp", "esp", "ah", "gre", "ipv6", "igmp", "pim", "ospf"]), - source=dict(required=False, type='str'), - destination=dict(required=False, type='str'), - target=dict(required=False, type='str'), - natreflection=dict(default='system-default', choices=["system-default", "enable", "purenat", "disable"]), - associated_rule=dict(default='associated', required=False, choices=["associated", "unassociated", "pass", "none"]), - nosync=dict(default=False, required=False, type='bool'), - after=dict(required=False, type='str'), - before=dict(required=False, type='str'), + descr=dict(required=True, type="str"), + state=dict(default="present", choices=["present", "absent"]), + disabled=dict(default=False, required=False, type="bool"), + nordr=dict(default=False, required=False, type="bool"), + interface=dict(required=False, type="str"), + ipprotocol=dict(default="inet", choices=["inet", "inet6"]), + protocol=dict( + default="tcp", + required=False, + choices=[ + "tcp", + "udp", + "tcp/udp", + "icmp", + "esp", + "ah", + "gre", + "ipv6", + "igmp", + "pim", + "ospf", + ], + ), + source=dict(required=False, type="str"), + destination=dict(required=False, type="str"), + target=dict(required=False, type="str"), + natreflection=dict( + default="system-default", + choices=["system-default", "enable", "purenat", "disable"], + ), + associated_rule=dict( + default="associated", + required=False, + choices=["associated", "unassociated", "pass", "none"], + ), + nosync=dict(default=False, required=False, type="bool"), + after=dict(required=False, type="str"), + before=dict(required=False, type="str"), ) NAT_PORT_FORWARD_REQUIRED_IF = [ @@ -34,11 +62,11 @@ class PFSenseNatPortForwardModule(PFSenseModuleBase): - """ module managing pfsense NAT rules """ + """module managing pfsense NAT rules""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return NAT_PORT_FORWARD_ARGUMENT_SPEC ############################## @@ -55,9 +83,9 @@ def __init__(self, module, pfsense=None): self.before = None self.position_changed = False - self.root_elt = self.pfsense.get_element('nat') + self.root_elt = self.pfsense.get_element("nat") if self.root_elt is None: - self.root_elt = self.pfsense.new_element('nat') + self.root_elt = self.pfsense.new_element("nat") self.pfsense.root.append(self.root_elt) self.pfsense_rule_module = None @@ -66,107 +94,131 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" obj = dict() self.obj = obj - obj['descr'] = self.params['descr'] - if self.params['state'] == 'present': - obj['interface'] = self.pfsense.parse_interface(self.params['interface']) + obj["descr"] = self.params["descr"] + if self.params["state"] == "present": + obj["interface"] = self.pfsense.parse_interface(self.params["interface"]) if self.pfsense.is_at_least_2_5_0(): - self._get_ansible_param(obj, 'ipprotocol') - self._get_ansible_param(obj, 'protocol') - self._get_ansible_param(obj, 'poolopts') - self._get_ansible_param(obj, 'source_hash_key') - self._get_ansible_param(obj, 'natport') - - self._get_ansible_param(obj, 'natreflection') - if obj['natreflection'] == 'system-default': - del obj['natreflection'] - - if self.params['associated_rule'] == 'pass': - obj['associated-rule-id'] = 'pass' - elif self.params['associated_rule'] == 'unassociated' and self._find_target() is not None: - self.module.fail_json(msg='You cannot set an unassociated filter rule if the NAT rule already exists.') + self._get_ansible_param(obj, "ipprotocol") + self._get_ansible_param(obj, "protocol") + self._get_ansible_param(obj, "poolopts") + self._get_ansible_param(obj, "source_hash_key") + self._get_ansible_param(obj, "natport") + + self._get_ansible_param(obj, "natreflection") + if obj["natreflection"] == "system-default": + del obj["natreflection"] + + if self.params["associated_rule"] == "pass": + obj["associated-rule-id"] = "pass" + elif ( + self.params["associated_rule"] == "unassociated" + and self._find_target() is not None + ): + self.module.fail_json( + msg="You cannot set an unassociated filter rule if the NAT rule already exists." + ) else: - obj['associated-rule-id'] = '' + obj["associated-rule-id"] = "" - self._get_ansible_param_bool(obj, 'disabled') - self._get_ansible_param_bool(obj, 'nordr') - self._get_ansible_param_bool(obj, 'nosync') + self._get_ansible_param_bool(obj, "disabled") + self._get_ansible_param_bool(obj, "nordr") + self._get_ansible_param_bool(obj, "nosync") - if 'after' in self.params and self.params['after'] is not None: - self.after = self.params['after'] + if "after" in self.params and self.params["after"] is not None: + self.after = self.params["after"] - if 'before' in self.params and self.params['before'] is not None: - self.before = self.params['before'] + if "before" in self.params and self.params["before"] is not None: + self.before = self.params["before"] - obj['source'] = self.pfsense.parse_address(self.params['source'], allow_self=False) - obj['destination'] = self.pfsense.parse_address(self.params['destination']) + obj["source"] = self.pfsense.parse_address( + self.params["source"], allow_self=False + ) + obj["destination"] = self.pfsense.parse_address(self.params["destination"]) self._parse_target_address(obj) return obj def _parse_target_address(self, obj): - """ validate param address field and returns it as a dict """ + """validate param address field and returns it as a dict""" - if self.params.get('target') is None or self.params['target'] == '': - self.module.fail_json(msg='The field Redirect target IP is required.') + if self.params.get("target") is None or self.params["target"] == "": + self.module.fail_json(msg="The field Redirect target IP is required.") - param = self.params['target'] - addr = param.split(':') + param = self.params["target"] + addr = param.split(":") if len(addr) > 2: - self.module.fail_json(msg='Cannot parse address %s' % (param)) + self.module.fail_json(msg="Cannot parse address %s" % (param)) address = addr[0] ports = addr[1] if len(addr) > 1 else None - if self.pfsense.find_alias(address, 'host') is not None or self.pfsense.is_ipv4_address(address): - obj['target'] = address + if self.pfsense.find_alias( + address, "host" + ) is not None or self.pfsense.is_ipv4_address(address): + obj["target"] = address else: - self.module.fail_json(msg='"%s" is not a valid redirect target IP address or host alias.' % (param)) + self.module.fail_json( + msg='"%s" is not a valid redirect target IP address or host alias.' + % (param) + ) if ports is None: - if self.params['protocol'] in ["tcp", "udp", "tcp/udp"]: - self.module.fail_json(msg='Must specify a target port with protocol "{0}".'.format(self.params['protocol'])) + if self.params["protocol"] in ["tcp", "udp", "tcp/udp"]: + self.module.fail_json( + msg='Must specify a target port with protocol "{0}".'.format( + self.params["protocol"] + ) + ) else: # pfSense seems to always add an empty local-port element - obj['local-port'] = '' + obj["local-port"] = "" if ports is not None: - if self.params['protocol'] not in ["tcp", "udp", "tcp/udp"]: - self.module.fail_json(msg='Cannot specify a target port with protocol "{0}".'.format(self.params['protocol'])) + if self.params["protocol"] not in ["tcp", "udp", "tcp/udp"]: + self.module.fail_json( + msg='Cannot specify a target port with protocol "{0}".'.format( + self.params["protocol"] + ) + ) elif self.pfsense.is_port_or_alias(ports): - obj['local-port'] = ports + obj["local-port"] = ports else: - self.module.fail_json(msg='"{0}" is not a valid redirect target port. It must be a port alias or integer between 1 and 65535.'.format(ports)) + self.module.fail_json( + msg='"{0}" is not a valid redirect target port. It must be a port alias or integer between 1 and 65535.'.format( + ports + ) + ) def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" - if self.params.get('after') and self.params.get('before'): - self.module.fail_json(msg='Cannot specify both after and before') - elif self.params.get('after'): - if self.params['after'] == self.params['descr']: - self.module.fail_json(msg='Cannot specify the current rule in after') - elif self.params.get('before'): - if self.params['before'] == self.params['descr']: - self.module.fail_json(msg='Cannot specify the current rule in before') + if self.params.get("after") and self.params.get("before"): + self.module.fail_json(msg="Cannot specify both after and before") + elif self.params.get("after"): + if self.params["after"] == self.params["descr"]: + self.module.fail_json(msg="Cannot specify the current rule in after") + elif self.params.get("before"): + if self.params["before"] == self.params["descr"]: + self.module.fail_json(msg="Cannot specify the current rule in before") ############################## # XML processing # def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" self._set_associated_rule() self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) self._insert(self.target_elt) def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) - self.diff['before'] = before + self.diff["before"] = before changed = self._set_associated_rule(before) if self.pfsense.copy_dict_to_element(self.obj, self.target_elt): @@ -178,79 +230,79 @@ def _copy_and_update_target(self): if self._update_rule_position(self.target_elt): changed = True - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) return (before, changed) def _create_associated_rule(self): if self.pfsense_rule_module is None: self.pfsense_rule_module = PFSenseRuleModule(self.module, self.pfsense) params = dict() - params['name'] = 'NAT ' + self.params['descr'] - params['state'] = 'present' - params['action'] = 'pass' + params["name"] = "NAT " + self.params["descr"] + params["state"] = "present" + params["action"] = "pass" if self.pfsense.is_at_least_2_5_0(): - params['ipprotocol'] = 'inet' - params['statetype'] = 'keep state' - params['interface'] = self.params['interface'] - params['source'] = self.params['source'] - params['destination'] = self.params['target'] - params['disabled'] = self.params['disabled'] - params['protocol'] = self.params['protocol'] - if self.params['associated_rule'] == 'associated': - params['associated-rule-id'] = self.pfsense.uniqid('nat_', True) - self.obj['associated-rule-id'] = params['associated-rule-id'] - self.result['commands'] = list() + params["ipprotocol"] = "inet" + params["statetype"] = "keep state" + params["interface"] = self.params["interface"] + params["source"] = self.params["source"] + params["destination"] = self.params["target"] + params["disabled"] = self.params["disabled"] + params["protocol"] = self.params["protocol"] + if self.params["associated_rule"] == "associated": + params["associated-rule-id"] = self.pfsense.uniqid("nat_", True) + self.obj["associated-rule-id"] = params["associated-rule-id"] + self.result["commands"] = list() self.pfsense_rule_module.run(params) - self.result['commands'] += self.pfsense_rule_module.result['commands'] + self.result["commands"] += self.pfsense_rule_module.result["commands"] def _create_target(self): - """ create the XML target_elt """ - target_elt = self.pfsense.new_element('rule') + """create the XML target_elt""" + target_elt = self.pfsense.new_element("rule") return target_elt def _delete_associated_rule(self, ruleid, interface=None): - if ruleid is None or ruleid == '' or ruleid == 'pass': + if ruleid is None or ruleid == "" or ruleid == "pass": return if interface is None: - interface = self.params['interface'] + interface = self.params["interface"] self.pfsense_rule_module = PFSenseRuleModule(self.module, self.pfsense) params = dict() - if self.params['descr'] is None: - params['name'] = 'NAT ' + if self.params["descr"] is None: + params["name"] = "NAT " else: - params['name'] = 'NAT ' + self.params['descr'] - params['interface'] = interface - params['state'] = 'absent' - params['associated-rule-id'] = ruleid + params["name"] = "NAT " + self.params["descr"] + params["interface"] = interface + params["state"] = "absent" + params["associated-rule-id"] = ruleid self.pfsense_rule_module.run(params) - self.result['commands'] += self.pfsense_rule_module.result['commands'] + self.result["commands"] += self.pfsense_rule_module.result["commands"] def _find_rule_by_descr(self, descr): - """ find the XML target_elt """ + """find the XML target_elt""" for idx, rule_elt in enumerate(self.root_elt): - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue - if rule_elt.find('descr').text == descr: + if rule_elt.find("descr").text == descr: return (rule_elt, idx) return (None, None) def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" for rule_elt in self.root_elt: - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue - if rule_elt.find('descr').text == self.obj['descr']: + if rule_elt.find("descr").text == self.obj["descr"]: return rule_elt return None def _get_expected_rule_position(self): - """ get expected rule position in interface/floating """ - if self.before == 'bottom': + """get expected rule position in interface/floating""" + if self.before == "bottom": return len(self.root_elt) - elif self.after == 'top': + elif self.after == "top": return 0 elif self.after is not None: return self._get_rule_position(self.after) + 1 @@ -267,25 +319,29 @@ def _get_expected_rule_position(self): return -1 def _get_expected_rule_xml_index(self): - """ get expected rule index in xml """ - if self.before == 'bottom': + """get expected rule index in xml""" + if self.before == "bottom": return len(self.root_elt) - elif self.after == 'top': + elif self.after == "top": return 0 elif self.after is not None: found, i = self._find_rule_by_descr(self.after) if found is not None: return i + 1 else: - self.module.fail_json(msg='Failed to insert after rule=%s' % (self.after)) + self.module.fail_json( + msg="Failed to insert after rule=%s" % (self.after) + ) elif self.before is not None: found, i = self._find_rule_by_descr(self.before) if found is not None: return i else: - self.module.fail_json(msg='Failed to insert before rule=%s' % (self.before)) + self.module.fail_json( + msg="Failed to insert before rule=%s" % (self.before) + ) else: - found, i = self._find_rule_by_descr(self.obj['descr']) + found, i = self._find_rule_by_descr(self.obj["descr"]) if found is not None: return i return len(self.root_elt) @@ -293,57 +349,62 @@ def _get_expected_rule_xml_index(self): @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - return ['disabled', 'nordr', 'nosync', 'natreflection'] + """returns the list of params to remove if they are not set""" + return ["disabled", "nordr", "nosync", "natreflection"] def _get_rule_position(self, descr=None, fail=True): - """ get rule position in interface/floating """ + """get rule position in interface/floating""" if descr is None: - descr = self.obj['descr'] + descr = self.obj["descr"] (res, idx) = self._find_rule_by_descr(descr) if fail and res is None: - self.module.fail_json(msg='Failed to find rule=%s' % (descr)) + self.module.fail_json(msg="Failed to find rule=%s" % (descr)) return idx def _insert(self, rule_elt): - """ insert rule into xml """ + """insert rule into xml""" rule_xml_idx = self._get_expected_rule_xml_index() self.root_elt.insert(rule_xml_idx, rule_elt) def _pre_remove_target_elt(self): - """ processing before removing elt """ - ruleid_elt = self.target_elt.find('associated-rule-id') + """processing before removing elt""" + ruleid_elt = self.target_elt.find("associated-rule-id") if ruleid_elt is not None: self._delete_associated_rule(ruleid_elt.text) def _set_associated_rule(self, before=None): - """ manage changes to the associated rule """ + """manage changes to the associated rule""" if before is None: - if self.params['associated_rule'] == 'associated' or self.params['associated_rule'] == 'unassociated': + if ( + self.params["associated_rule"] == "associated" + or self.params["associated_rule"] == "unassociated" + ): self._create_associated_rule() else: - if self.params['associated_rule'] == 'associated': - if before['associated-rule-id'].startswith('nat_'): - if self.obj['interface'] != before['interface']: - self._delete_associated_rule(before['associated-rule-id'], before['interface']) + if self.params["associated_rule"] == "associated": + if before["associated-rule-id"].startswith("nat_"): + if self.obj["interface"] != before["interface"]: + self._delete_associated_rule( + before["associated-rule-id"], before["interface"] + ) else: - self.obj['associated-rule-id'] = before['associated-rule-id'] + self.obj["associated-rule-id"] = before["associated-rule-id"] return self._create_associated_rule() - elif before['associated-rule-id'].startswith('nat_'): - self._delete_associated_rule(before['associated-rule-id']) + elif before["associated-rule-id"].startswith("nat_"): + self._delete_associated_rule(before["associated-rule-id"]) def _update_rule_position(self, rule_elt): - """ move rule in xml if required """ + """move rule in xml if required""" current_position = self._get_rule_position() expected_position = self._get_expected_rule_position() if current_position == expected_position: self.position_changed = False return False - self.diff['before']['position'] = current_position - self.diff['after']['position'] = expected_position + self.diff["before"]["position"] = current_position + self.diff["after"]["position"] = expected_position self.root_elt.remove(rule_elt) self._insert(rule_elt) self.position_changed = True @@ -353,113 +414,176 @@ def _update_rule_position(self, rule_elt): # run # def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell('''require_once("filter.inc"); -if (filter_configure() == 0) { clear_subsystem_dirty('natconf'); clear_subsystem_dirty('filter'); }''') + """make the target pfsense reload""" + return self.pfsense.phpshell( + """require_once("filter.inc"); +if (filter_configure() == 0) { clear_subsystem_dirty('natconf'); clear_subsystem_dirty('filter'); }""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.obj['descr']) + """return obj's name""" + return "'{0}'".format(self.obj["descr"]) @staticmethod def fassociate(value): - """ associated-rule-id value formatting function """ - if value is None or value == '': - return 'none' + """associated-rule-id value formatting function""" + if value is None or value == "": + return "none" - if value == 'pass': - return 'pass' + if value == "pass": + return "pass" - return 'associated' + return "associated" @staticmethod def fnatreflection(value): - """ natreflection value formatting function """ - if value is None or value == 'none': + """natreflection value formatting function""" + if value is None or value == "none": return "'system-default'" return value @staticmethod def fprotocol(value): - """ protocol value formatting function """ - if value is None or value == 'none': - return 'any' + """protocol value formatting function""" + if value is None or value == "none": + return "any" return value def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" fafter = self._obj_to_log_fields(self.obj) if before is None: - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'nordr', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'interface') + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field( + self.params, "nordr", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.params, "interface") if self.pfsense.is_at_least_2_5_0(): - values += self.format_cli_field(self.params, 'ipprotocol', default='inet') - values += self.format_cli_field(self.params, 'protocol', default='tcp') - values += self.format_cli_field(self.params, 'source') - values += self.format_cli_field(self.params, 'destination') - values += self.format_cli_field(self.params, 'target') - values += self.format_cli_field(self.params, 'natreflection', default='system-default') - values += self.format_cli_field(self.params, 'associated_rule', default='associated') - values += self.format_cli_field(self.params, 'nosync', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'after') - values += self.format_cli_field(self.params, 'before') + values += self.format_cli_field( + self.params, "ipprotocol", default="inet" + ) + values += self.format_cli_field(self.params, "protocol", default="tcp") + values += self.format_cli_field(self.params, "source") + values += self.format_cli_field(self.params, "destination") + values += self.format_cli_field(self.params, "target") + values += self.format_cli_field( + self.params, "natreflection", default="system-default" + ) + values += self.format_cli_field( + self.params, "associated_rule", default="associated" + ) + values += self.format_cli_field( + self.params, "nosync", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.params, "after") + values += self.format_cli_field(self.params, "before") else: fbefore = self._obj_to_log_fields(before) - fafter['before'] = self.before - fafter['after'] = self.after + fafter["before"] = self.before + fafter["after"] = self.after - values += self.format_updated_cli_field(self.obj, before, 'disabled', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'nordr', fvalue=self.fvalue_bool, default=False, add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'interface', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, + before, + "disabled", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "nordr", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) + values += self.format_updated_cli_field( + fafter, fbefore, "interface", add_comma=(values) + ) if self.pfsense.is_at_least_2_5_0(): - values += self.format_updated_cli_field(self.obj, before, 'ipprotocol', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'protocol', fvalue=self.fprotocol, add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'source', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'destination', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'target', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "ipprotocol", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "protocol", fvalue=self.fprotocol, add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "source", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "destination", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "target", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "natreflection", + fvalue=self.fnatreflection, + default="system-default", + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, + before, + "associated-rule-id", + fvalue=self.fassociate, + fname="associated_rule", + add_comma=(values), + ) values += self.format_updated_cli_field( - self.obj, before, 'natreflection', fvalue=self.fnatreflection, default='system-default', add_comma=(values) + self.obj, + before, + "nosync", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), ) - values += self.format_updated_cli_field(self.obj, before, 'associated-rule-id', fvalue=self.fassociate, fname='associated_rule', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'nosync', fvalue=self.fvalue_bool, default=False, add_comma=(values)) if self.position_changed: - values += self.format_updated_cli_field(fafter, {}, 'after', log_none=False, add_comma=(values)) - values += self.format_updated_cli_field(fafter, {}, 'before', log_none=False, add_comma=(values)) + values += self.format_updated_cli_field( + fafter, {}, "after", log_none=False, add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, {}, "before", log_none=False, add_comma=(values) + ) return values @staticmethod def _obj_address_to_log_field(rule, addr): - """ return formated address from dict """ - field = '' + """return formated address from dict""" + field = "" if isinstance(rule[addr], dict): - if 'any' in rule[addr]: - field = 'any' - if 'address' in rule[addr]: - field = rule[addr]['address'] - if 'port' in rule[addr]: + if "any" in rule[addr]: + field = "any" + if "address" in rule[addr]: + field = rule[addr]["address"] + if "port" in rule[addr]: if field: - field += ':' - field += rule[addr]['port'] + field += ":" + field += rule[addr]["port"] else: field = rule[addr] return field def _obj_to_log_fields(self, rule): - """ return formated source and destination from dict """ + """return formated source and destination from dict""" res = {} - res['source'] = self._obj_address_to_log_field(rule, 'source') - res['destination'] = self._obj_address_to_log_field(rule, 'destination') - res['target'] = rule['target'] - if 'local-port' in rule: - res['target'] += ':' + rule['local-port'] - res['interface'] = self.pfsense.get_interface_display_name(rule['interface']) + res["source"] = self._obj_address_to_log_field(rule, "source") + res["destination"] = self._obj_address_to_log_field(rule, "destination") + res["target"] = rule["target"] + if "local-port" in rule: + res["target"] += ":" + rule["local-port"] + res["interface"] = self.pfsense.get_interface_display_name(rule["interface"]) return res diff --git a/plugins/module_utils/openvpn_client.py b/plugins/module_utils/openvpn_client.py index ae5c6186..20372195 100644 --- a/plugins/module_utils/openvpn_client.py +++ b/plugins/module_utils/openvpn_client.py @@ -5,67 +5,86 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import base64 import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) OPENVPN_CLIENT_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - mode=dict(default='p2p_tls', required=False, choices=['p2p_tls', 'p2p_shared_key']), - authmode=dict(default=list(), required=False, type='list', elements='str'), - state=dict(default='present', choices=['present', 'absent']), - custom_options=dict(default=None, required=False, type='str'), - disable=dict(default=False, required=False, type='bool'), - interface=dict(default='wan', required=False, type='str'), - server_addr=dict(required=True, type='str'), - server_port=dict(default=1194, required=False, type='int'), - protocol=dict(default='UDP4', required=False, choices=['UDP4', 'TCP4']), - dev_mode=dict(default='tun', required=False, choices=['tun', 'tap']), - tls=dict(required=False, type='str'), - tls_type=dict(default='auth', required=False, choices=['auth', 'crypt']), - ca=dict(required=False, type='str'), - crl=dict(required=False, type='str'), - cert=dict(required=False, type='str'), - cert_depth=dict(default=1, required=False, type='int'), - strictusercn=dict(default=False, required=False, type='bool'), - shared_key=dict(required=False, type='str', no_log=True), - dh_length=dict(default=2048, required=False, type='int'), - ecdh_curve=dict(default='none', required=False, choices=['none', 'prime256v1', 'secp384r1', 'secp521r1']), - ncp_enable=dict(default=False, required=False, type='bool'), + name=dict(required=True, type="str"), + mode=dict(default="p2p_tls", required=False, choices=["p2p_tls", "p2p_shared_key"]), + authmode=dict(default=list(), required=False, type="list", elements="str"), + state=dict(default="present", choices=["present", "absent"]), + custom_options=dict(default=None, required=False, type="str"), + disable=dict(default=False, required=False, type="bool"), + interface=dict(default="wan", required=False, type="str"), + server_addr=dict(required=True, type="str"), + server_port=dict(default=1194, required=False, type="int"), + protocol=dict(default="UDP4", required=False, choices=["UDP4", "TCP4"]), + dev_mode=dict(default="tun", required=False, choices=["tun", "tap"]), + tls=dict(required=False, type="str"), + tls_type=dict(default="auth", required=False, choices=["auth", "crypt"]), + ca=dict(required=False, type="str"), + crl=dict(required=False, type="str"), + cert=dict(required=False, type="str"), + cert_depth=dict(default=1, required=False, type="int"), + strictusercn=dict(default=False, required=False, type="bool"), + shared_key=dict(required=False, type="str", no_log=True), + dh_length=dict(default=2048, required=False, type="int"), + ecdh_curve=dict( + default="none", + required=False, + choices=["none", "prime256v1", "secp384r1", "secp521r1"], + ), + ncp_enable=dict(default=False, required=False, type="bool"), # ncp_ciphers=dict(default=list('AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'), required=False, # choices=['AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'], type='list', elements='str'), - data_ciphers=dict(default=None, required=False, choices=['AES-256-CBC', 'AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'], type='list', elements='str'), - data_ciphers_fallback=dict(default='AES-256-CBC', required=False, choices=['AES-256-CBC', 'AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305']), - digest=dict(default='SHA256', required=False, type='str'), - tunnel_network=dict(default='', required=False, type='str'), - tunnel_networkv6=dict(default='', required=False, type='str'), - remote_network=dict(default='', required=False, type='str'), - remote_networkv6=dict(default='', required=False, type='str'), - gwredir=dict(default=False, required=False, type='bool'), - gwredir6=dict(default=False, required=False, type='bool'), - maxclients=dict(default=None, required=False, type='int'), - compression=dict(default='adaptive', required=False, choices=['adaptive', '']), - compression_push=dict(default=False, required=False, type='bool'), - passtos=dict(default=False, required=False, type='bool'), - client2client=dict(default=False, required=False, type='bool'), - dynamic_ip=dict(default=False, required=False, type='bool'), - topology=dict(default='subnet', required=False, choices=['net30', 'subnet']), - dns_domain=dict(default='', required=False, type='str'), - dns_client1=dict(default='', required=False, type='str'), - dns_client2=dict(default='', required=False, type='str'), - dns_client3=dict(default='', required=False, type='str'), - dns_client4=dict(default='', required=False, type='str'), - push_register_dns=dict(default=False, required=False, type='bool'), - create_gw=dict(default='both', required=False, choices=['both', 'v4only', 'v6only']), - verbosity_level=dict(default=3, required=False, type='int'), + data_ciphers=dict( + default=None, + required=False, + choices=["AES-256-CBC", "AES-256-GCM", "AES-128-GCM", "CHACHA20-POLY1305"], + type="list", + elements="str", + ), + data_ciphers_fallback=dict( + default="AES-256-CBC", + required=False, + choices=["AES-256-CBC", "AES-256-GCM", "AES-128-GCM", "CHACHA20-POLY1305"], + ), + digest=dict(default="SHA256", required=False, type="str"), + tunnel_network=dict(default="", required=False, type="str"), + tunnel_networkv6=dict(default="", required=False, type="str"), + remote_network=dict(default="", required=False, type="str"), + remote_networkv6=dict(default="", required=False, type="str"), + gwredir=dict(default=False, required=False, type="bool"), + gwredir6=dict(default=False, required=False, type="bool"), + maxclients=dict(default=None, required=False, type="int"), + compression=dict(default="adaptive", required=False, choices=["adaptive", ""]), + compression_push=dict(default=False, required=False, type="bool"), + passtos=dict(default=False, required=False, type="bool"), + client2client=dict(default=False, required=False, type="bool"), + dynamic_ip=dict(default=False, required=False, type="bool"), + topology=dict(default="subnet", required=False, choices=["net30", "subnet"]), + dns_domain=dict(default="", required=False, type="str"), + dns_client1=dict(default="", required=False, type="str"), + dns_client2=dict(default="", required=False, type="str"), + dns_client3=dict(default="", required=False, type="str"), + dns_client4=dict(default="", required=False, type="str"), + push_register_dns=dict(default=False, required=False, type="bool"), + create_gw=dict( + default="both", required=False, choices=["both", "v4only", "v6only"] + ), + verbosity_level=dict(default=3, required=False, type="int"), ) OPENVPN_CLIENT_REQUIRED_IF = [ - ['mode', 'p2p_tls', ['ca']], - ['mode', 'p2p_shared_key', ['shared_key']], + ["mode", "p2p_tls", ["ca"]], + ["mode", "p2p_shared_key", ["shared_key"]], ] OPENVPN_CLIENT_PHP_COMMAND_PREFIX = """ @@ -73,19 +92,25 @@ $ovpn = config_get_path('openvpn/openvpn-client')[{idx}]; """ -OPENVPN_CLIENT_PHP_COMMAND_SET = OPENVPN_CLIENT_PHP_COMMAND_PREFIX + """ +OPENVPN_CLIENT_PHP_COMMAND_SET = ( + OPENVPN_CLIENT_PHP_COMMAND_PREFIX + + """ openvpn_resync('client',$ovpn); """ +) -OPENVPN_CLIENT_PHP_COMMAND_DEL = OPENVPN_CLIENT_PHP_COMMAND_PREFIX + """ +OPENVPN_CLIENT_PHP_COMMAND_DEL = ( + OPENVPN_CLIENT_PHP_COMMAND_PREFIX + + """ openvpn_delete($ovpn); unset($ovpn); openvpn_resync('client',$ovpn); """ +) class PFSenseOpenVPNClientModule(PFSenseModuleBase): - """ module managing pfSense OpenVPN configuration """ + """module managing pfSense OpenVPN configuration""" ############################## # init @@ -93,12 +118,14 @@ class PFSenseOpenVPNClientModule(PFSenseModuleBase): def __init__(self, module, pfsense=None): super(PFSenseOpenVPNClientModule, self).__init__(module, pfsense) self.name = "pfsense_openvpn" - self.root_elt = self.pfsense.get_element('openvpn', create_node=True) + self.root_elt = self.pfsense.get_element("openvpn", create_node=True) self.obj = dict() - cmd = ('require_once("openvpn.inc");;' - '$digestlist = openvpn_get_digestlist();' - 'echo json_encode($digestlist);') + cmd = ( + 'require_once("openvpn.inc");;' + "$digestlist = openvpn_get_digestlist();" + "echo json_encode($digestlist);" + ) self.digestlist = self.pfsense.php(cmd) ############################## @@ -111,116 +138,143 @@ def _get_digest_name(self, digest: str): self.module.fail_json(msg=f"Invalid digest '{digest}'") def _params_to_obj(self): - """ return dict from module params """ + """return dict from module params""" obj = dict() - obj['description'] = self.params['name'] - if self.params['state'] == 'present': - obj['custom_options'] = self.params['custom_options'] - self._get_ansible_param_bool(obj, 'disable') - self._get_ansible_param_bool(obj, 'strictusercn') - obj['mode'] = self.params['mode'] - obj['dev_mode'] = self.params['dev_mode'] - obj['interface'] = self.params['interface'] - obj['protocol'] = self.params['protocol'] - obj['server_addr'] = self.params['server_addr'] - obj['server_port'] = str(self.params['server_port']) - self._get_ansible_param(obj, 'maxclients') - obj['verbosity_level'] = str(self.params['verbosity_level']) - obj['data_ciphers_fallback'] = self.params['data_ciphers_fallback'] - obj['data_ciphers'] = ",".join(self.params['data_ciphers']) - self._get_ansible_param_bool(obj, 'ncp_enable', 'enabled') - self._get_ansible_param_bool(obj, 'gwredir') - self._get_ansible_param_bool(obj, 'gwredirr6') - self._get_ansible_param_bool(obj, 'compression_push') - self._get_ansible_param_bool(obj, 'passtos') - self._get_ansible_param_bool(obj, 'client2client') - self._get_ansible_param_bool(obj, 'dynamic_ip') - self._get_ansible_param_bool(obj, 'push_register_dns') - obj['digest'] = self._get_digest_name(self.params['digest']) - obj['tunnel_network'] = self.params['tunnel_network'] - obj['tunnel_networkv6'] = self.params['tunnel_networkv6'] - obj['remote_network'] = self.params['remote_network'] - obj['remote_networkv6'] = self.params['remote_networkv6'] - obj['compression'] = self.params['compression'] - obj['topology'] = self.params['topology'] - obj['create_gw'] = self.params['create_gw'] - - if 'user' in self.params['mode']: - obj['authmode'] = ",".join(self.params['authmode']) - - if 'tls' in self.params['mode']: + obj["description"] = self.params["name"] + if self.params["state"] == "present": + obj["custom_options"] = self.params["custom_options"] + self._get_ansible_param_bool(obj, "disable") + self._get_ansible_param_bool(obj, "strictusercn") + obj["mode"] = self.params["mode"] + obj["dev_mode"] = self.params["dev_mode"] + obj["interface"] = self.params["interface"] + obj["protocol"] = self.params["protocol"] + obj["server_addr"] = self.params["server_addr"] + obj["server_port"] = str(self.params["server_port"]) + self._get_ansible_param(obj, "maxclients") + obj["verbosity_level"] = str(self.params["verbosity_level"]) + obj["data_ciphers_fallback"] = self.params["data_ciphers_fallback"] + obj["data_ciphers"] = ",".join(self.params["data_ciphers"]) + self._get_ansible_param_bool(obj, "ncp_enable", "enabled") + self._get_ansible_param_bool(obj, "gwredir") + self._get_ansible_param_bool(obj, "gwredirr6") + self._get_ansible_param_bool(obj, "compression_push") + self._get_ansible_param_bool(obj, "passtos") + self._get_ansible_param_bool(obj, "client2client") + self._get_ansible_param_bool(obj, "dynamic_ip") + self._get_ansible_param_bool(obj, "push_register_dns") + obj["digest"] = self._get_digest_name(self.params["digest"]) + obj["tunnel_network"] = self.params["tunnel_network"] + obj["tunnel_networkv6"] = self.params["tunnel_networkv6"] + obj["remote_network"] = self.params["remote_network"] + obj["remote_networkv6"] = self.params["remote_networkv6"] + obj["compression"] = self.params["compression"] + obj["topology"] = self.params["topology"] + obj["create_gw"] = self.params["create_gw"] + + if "user" in self.params["mode"]: + obj["authmode"] = ",".join(self.params["authmode"]) + + if "tls" in self.params["mode"]: # Find the caref id for the named CA if self.params is not None: - ca_elt = self.pfsense.find_ca_elt(self.params['ca']) + ca_elt = self.pfsense.find_ca_elt(self.params["ca"]) if ca_elt is None: - self.module.fail_json(msg='%s is not a valid certificate authority' % (self.params['ca'])) - obj['caref'] = ca_elt.find('refid').text + self.module.fail_json( + msg="%s is not a valid certificate authority" + % (self.params["ca"]) + ) + obj["caref"] = ca_elt.find("refid").text # Find the crlref id for the named CRL if any - if self.params['crl'] is not None: - crl_elt = self.pfsense.find_crl_elt(self.params['crl']) + if self.params["crl"] is not None: + crl_elt = self.pfsense.find_crl_elt(self.params["crl"]) if crl_elt is None: - self.module.fail_json(msg='%s is not a valid certificate revocation list' % (self.params['crl'])) - obj['crlref'] = crl_elt.find('refid').text + self.module.fail_json( + msg="%s is not a valid certificate revocation list" + % (self.params["crl"]) + ) + obj["crlref"] = crl_elt.find("refid").text else: - obj['crlref'] = '' + obj["crlref"] = "" # Find the certref id for the named certificate if any - if self.params['cert'] is not None: - cert_elt = self.pfsense.find_cert_elt(self.params['cert']) + if self.params["cert"] is not None: + cert_elt = self.pfsense.find_cert_elt(self.params["cert"]) if cert_elt is None: - self.module.fail_json(msg='%s is not a valid certificate' % (self.params['cert'])) - obj['certref'] = cert_elt.find('refid').text + self.module.fail_json( + msg="%s is not a valid certificate" % (self.params["cert"]) + ) + obj["certref"] = cert_elt.find("refid").text - if self.params['tls'] is not None: - obj['tls'] = self.params['tls'] - obj['tls_type'] = self.params['tls_type'] + if self.params["tls"] is not None: + obj["tls"] = self.params["tls"] + obj["tls_type"] = self.params["tls_type"] - if self.params['mode'] == 'p2p_shared_key': - obj['shared_key'] = self.params['shared_key'] + if self.params["mode"] == "p2p_shared_key": + obj["shared_key"] = self.params["shared_key"] return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - self.pfsense.validate_string(params['name'], 'openvpn') + self.pfsense.validate_string(params["name"], "openvpn") - if params['state'] == 'absent': + if params["state"] == "absent": return True # tls is not valid for p2p_shared_key - if params['mode'] == 'p2p_shared_key' and params['tls'] is not None: - self.module.fail_json(msg='tls parameter is not valied with p2p_shared_key mode.') + if params["mode"] == "p2p_shared_key" and params["tls"] is not None: + self.module.fail_json( + msg="tls parameter is not valied with p2p_shared_key mode." + ) # check tunnel_networks - can be network alias or non-strict IP CIDR network - self.pfsense.validate_openvpn_tunnel_network(params.get('tunnel_network'), 'ipv4') - self.pfsense.validate_openvpn_tunnel_network(params.get('tunnel_network6'), 'ipv6') + self.pfsense.validate_openvpn_tunnel_network( + params.get("tunnel_network"), "ipv4" + ) + self.pfsense.validate_openvpn_tunnel_network( + params.get("tunnel_network6"), "ipv6" + ) # Check auth clients - if len(params['authmode']) > 0: - system = self.pfsense.get_element('system') - for authsrv in params['authmode']: + if len(params["authmode"]) > 0: + system = self.pfsense.get_element("system") + for authsrv in params["authmode"]: if len(system.findall("authclient[name='{0}']".format(authsrv))) == 0: - self.module.fail_json(msg='Cannot find authentication client {0}.'.format(authsrv)) + self.module.fail_json( + msg="Cannot find authentication client {0}.".format(authsrv) + ) # validate key - for param in ['shared_key', 'tls']: + for param in ["shared_key", "tls"]: if params[param] is not None: key = params[param] - if key == 'generate': + if key == "generate": # generate during params_to_obj pass - elif re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL): + elif re.search( + "^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$", + key, + flags=re.MULTILINE | re.DOTALL, + ): params[param] = base64.b64encode(key.encode()).decode() else: key_decoded = base64.b64decode(key.encode()).decode() - if not re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', - key_decoded, flags=re.MULTILINE | re.DOTALL): - self.module.fail_json(msg='Could not recognize {0} key format: {1}'.format(param, key_decoded)) + if not re.search( + "^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$", + key_decoded, + flags=re.MULTILINE | re.DOTALL, + ): + self.module.fail_json( + msg="Could not recognize {0} key format: {1}".format( + param, key_decoded + ) + ) def _nextvpnid(self): - """ find next available vpnid """ + """find next available vpnid""" vpnid = 1 while len(self.root_elt.findall("*[vpnid='{0}']".format(vpnid))) != 0: vpnid += 1 @@ -229,8 +283,8 @@ def _nextvpnid(self): ############################## # XML processing # - def _find_openvpn_client(self, value, field='description'): - """ return openvpn-client element """ + def _find_openvpn_client(self, value, field="description"): + """return openvpn-client element""" i = 0 for elt in self.root_elt: field_elt = elt.find(field) @@ -246,34 +300,42 @@ def _find_last_openvpn_idx(self): return i def _copy_and_update_target(self): - """ update the XML target_elt """ - (before, changed) = super(PFSenseOpenVPNClientModule, self)._copy_and_update_target() + """update the XML target_elt""" + (before, changed) = super( + PFSenseOpenVPNClientModule, self + )._copy_and_update_target() if not changed: - self.diff['after'] = self.obj + self.diff["after"] = self.obj return (before, changed) def _create_target(self): - """ create the XML target_elt """ - target_elt = self.pfsense.new_element('openvpn-client') - self.obj['vpnid'] = self._nextvpnid() - self.diff['before'] = '' - self.diff['after'] = self.obj - self.result['changed'] = True + """create the XML target_elt""" + target_elt = self.pfsense.new_element("openvpn-client") + self.obj["vpnid"] = self._nextvpnid() + self.diff["before"] = "" + self.diff["after"] = self.obj + self.result["changed"] = True self.idx = self._find_last_openvpn_idx() return target_elt def _find_target(self): - """ find the XML target_elt """ - (target_elt, self.idx) = self._find_openvpn_client(self.obj['description']) - for param in ['shared_key', 'tls']: + """find the XML target_elt""" + (target_elt, self.idx) = self._find_openvpn_client(self.obj["description"]) + for param in ["shared_key", "tls"]: current_elt = self.pfsense.get_element(param, target_elt) - if self.params[param] == 'generate': + if self.params[param] == "generate": if current_elt is None: - (dummy, key, stderr) = self.module.run_command('/usr/local/sbin/openvpn --genkey secret /dev/stdout') + (dummy, key, stderr) = self.module.run_command( + "/usr/local/sbin/openvpn --genkey secret /dev/stdout" + ) if stderr != "": - self.module.fail_json(msg='generate for "{0}" secret key: {1}'.format(param, stderr)) + self.module.fail_json( + msg='generate for "{0}" secret key: {1}'.format( + param, stderr + ) + ) self.obj[param] = base64.b64encode(key.encode()).decode() self.result[param] = self.obj[param] else: @@ -281,36 +343,42 @@ def _find_target(self): return target_elt def _remove_target_elt(self): - """ delete target_elt from xml """ + """delete target_elt from xml""" super(PFSenseOpenVPNClientModule, self)._remove_target_elt() - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) ############################## # run # def _remove(self): - """ delete obj """ - self.diff['after'] = '' - self.diff['before'] = '' + """delete obj""" + self.diff["after"] = "" + self.diff["before"] = "" super(PFSenseOpenVPNClientModule, self)._remove() - return self.pfsense.phpshell(OPENVPN_CLIENT_PHP_COMMAND_DEL.format(idx=self.idx)) + return self.pfsense.phpshell( + OPENVPN_CLIENT_PHP_COMMAND_DEL.format(idx=self.idx) + ) def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell(OPENVPN_CLIENT_PHP_COMMAND_SET.format(idx=self.idx)) + """make the target pfsense reload""" + return self.pfsense.phpshell( + OPENVPN_CLIENT_PHP_COMMAND_SET.format(idx=self.idx) + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'" + self.obj['description'] + "'" + """return obj's name""" + return "'" + self.obj["description"] + "'" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'description') + values += self.format_cli_field(self.obj, "description") else: - values += self.format_updated_cli_field(self.obj, before, 'description', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "description", add_comma=(values) + ) return values diff --git a/plugins/module_utils/openvpn_override.py b/plugins/module_utils/openvpn_override.py index e099ed09..6f024e74 100644 --- a/plugins/module_utils/openvpn_override.py +++ b/plugins/module_utils/openvpn_override.py @@ -5,58 +5,70 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) OPENVPN_OVERRIDE_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - state=dict(default='present', choices=['present', 'absent']), - server_list=dict(default=None, type='list', elements='str'), - disable=dict(default=False, required=False, type='bool'), - descr=dict(default=None, required=False, type='str'), - block=dict(default=False, required=False, type='bool'), - tunnel_network=dict(default=None, required=False, type='str'), - tunnel_networkv6=dict(default=None, required=False, type='str'), - local_network=dict(default=None, required=False, type='str'), - local_networkv6=dict(default=None, required=False, type='str'), - remote_network=dict(default=None, required=False, type='str'), - remote_networkv6=dict(default=None, required=False, type='str'), - gwredir=dict(default=False, required=False, type='bool'), - push_reset=dict(default=False, required=False, type='bool'), - netbios_enable=dict(default=False, required=False, type='bool'), - netbios_ntype=dict(required=False, choices=['none', 'b-node', 'p-node', 'm-node', 'h-node']), - netbios_scope=dict(required=False, type='str'), - wins_server_enable=dict(default=False, required=False, type='bool'), - custom_options=dict(default=None, required=False, type='str'), + name=dict(required=True, type="str"), + state=dict(default="present", choices=["present", "absent"]), + server_list=dict(default=None, type="list", elements="str"), + disable=dict(default=False, required=False, type="bool"), + descr=dict(default=None, required=False, type="str"), + block=dict(default=False, required=False, type="bool"), + tunnel_network=dict(default=None, required=False, type="str"), + tunnel_networkv6=dict(default=None, required=False, type="str"), + local_network=dict(default=None, required=False, type="str"), + local_networkv6=dict(default=None, required=False, type="str"), + remote_network=dict(default=None, required=False, type="str"), + remote_networkv6=dict(default=None, required=False, type="str"), + gwredir=dict(default=False, required=False, type="bool"), + push_reset=dict(default=False, required=False, type="bool"), + netbios_enable=dict(default=False, required=False, type="bool"), + netbios_ntype=dict( + required=False, choices=["none", "b-node", "p-node", "m-node", "h-node"] + ), + netbios_scope=dict(required=False, type="str"), + wins_server_enable=dict(default=False, required=False, type="bool"), + custom_options=dict(default=None, required=False, type="str"), ) -OPENVPN_OVERRIDE_REQUIRED_IF = [ -] +OPENVPN_OVERRIDE_REQUIRED_IF = [] OPENVPN_OVERRIDE_PHP_COMMAND_PREFIX = """ require_once('openvpn.inc'); $csc = config_get_path('openvpn/openvpn-csc')[{idx}]; """ -OPENVPN_OVERRIDE_PHP_COMMAND_SET = OPENVPN_OVERRIDE_PHP_COMMAND_PREFIX + """ +OPENVPN_OVERRIDE_PHP_COMMAND_SET = ( + OPENVPN_OVERRIDE_PHP_COMMAND_PREFIX + + """ openvpn_resync_csc($csc); """ +) -OPENVPN_OVERRIDE_PHP_COMMAND_DEL = OPENVPN_OVERRIDE_PHP_COMMAND_PREFIX + """ +OPENVPN_OVERRIDE_PHP_COMMAND_DEL = ( + OPENVPN_OVERRIDE_PHP_COMMAND_PREFIX + + """ openvpn_delete_csc($csc); unset($csc); """ +) class PFSenseOpenVPNOverrideModule(PFSenseModuleBase): - """ module managing pfSense OpenVPN Client Specific Overrides """ + """module managing pfSense OpenVPN Client Specific Overrides""" - from ansible_collections.pfsensible.core.plugins.module_utils.__impl.checks import validate_openvpn_tunnel_network + from ansible_collections.pfsensible.core.plugins.module_utils.__impl.checks import ( + validate_openvpn_tunnel_network, + ) @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return OPENVPN_OVERRIDE_ARGUMENT_SPEC ############################## @@ -65,86 +77,120 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseOpenVPNOverrideModule, self).__init__(module, pfsense) self.name = "pfsense_openvpn_override" - self.root_elt = self.pfsense.get_element('openvpn') - self.openvpn_csc_elt = self.root_elt.findall('openvpn-csc') + self.root_elt = self.pfsense.get_element("openvpn") + self.openvpn_csc_elt = self.root_elt.findall("openvpn-csc") self.obj = dict() ############################## # params processing # def _params_to_obj(self): - """ return dict from module params """ + """return dict from module params""" obj = dict() - obj['common_name'] = self.params['name'] - if self.params['state'] == 'present': + obj["common_name"] = self.params["name"] + if self.params["state"] == "present": # Find the ids for server names server_list = list() - if self.params['server_list'] is not None: - for server in self.params['server_list']: - vpnid = '' - if isinstance(server, int) or (isinstance(server, str) and server.isdigit()): - openvpn_server_elt = self.pfsense.find_elt('openvpn-server', str(server), 'vpnid', root_elt=self.root_elt) + if self.params["server_list"] is not None: + for server in self.params["server_list"]: + vpnid = "" + if isinstance(server, int) or ( + isinstance(server, str) and server.isdigit() + ): + openvpn_server_elt = self.pfsense.find_elt( + "openvpn-server", + str(server), + "vpnid", + root_elt=self.root_elt, + ) else: - openvpn_server_elt = self.pfsense.find_elt('openvpn-server', server, 'description', root_elt=self.root_elt) + openvpn_server_elt = self.pfsense.find_elt( + "openvpn-server", + server, + "description", + root_elt=self.root_elt, + ) if openvpn_server_elt is None: - self.module.fail_json(msg="Could not find openvpn server '%s'" % (server)) - vpnid = openvpn_server_elt.find('vpnid').text + self.module.fail_json( + msg="Could not find openvpn server '%s'" % (server) + ) + vpnid = openvpn_server_elt.find("vpnid").text server_list.append(vpnid) - obj['server_list'] = ','.join(server_list) - self.result['vpnids'] = server_list - - obj['custom_options'] = self.params['custom_options'] - obj['description'] = self.params['descr'] - self._get_ansible_param_bool(obj, 'disable') - self._get_ansible_param_bool(obj, 'block', force=True, value='yes') - self._get_ansible_param_bool(obj, 'gwredir', force=True, value='yes') + obj["server_list"] = ",".join(server_list) + self.result["vpnids"] = server_list + + obj["custom_options"] = self.params["custom_options"] + obj["description"] = self.params["descr"] + self._get_ansible_param_bool(obj, "disable") + self._get_ansible_param_bool(obj, "block", force=True, value="yes") + self._get_ansible_param_bool(obj, "gwredir", force=True, value="yes") if self.pfsense.config_version >= 23.4: - self._get_ansible_param_bool(obj, 'push_reset') + self._get_ansible_param_bool(obj, "push_reset") else: - self._get_ansible_param_bool(obj, 'push_reset', force=True, value='yes') - obj['tunnel_network'] = self.params['tunnel_network'] - obj['tunnel_networkv6'] = self.params['tunnel_networkv6'] - obj['local_network'] = self.params['local_network'] - obj['local_networkv6'] = self.params['local_networkv6'] - obj['remote_network'] = self.params['remote_network'] - obj['remote_networkv6'] = self.params['remote_networkv6'] - self._get_ansible_param_bool(obj, 'netbios_enable') - if self.params['netbios_enable']: - obj['netbios_ntype'] = self.params['netbios_ntype'] - obj['netbios_scope'] = str(self.params['netbios_scope']) - self._get_ansible_param(obj, 'netbios_scope') - self._get_ansible_param_bool(obj, 'wins_server_enable') + self._get_ansible_param_bool(obj, "push_reset", force=True, value="yes") + obj["tunnel_network"] = self.params["tunnel_network"] + obj["tunnel_networkv6"] = self.params["tunnel_networkv6"] + obj["local_network"] = self.params["local_network"] + obj["local_networkv6"] = self.params["local_networkv6"] + obj["remote_network"] = self.params["remote_network"] + obj["remote_networkv6"] = self.params["remote_networkv6"] + self._get_ansible_param_bool(obj, "netbios_enable") + if self.params["netbios_enable"]: + obj["netbios_ntype"] = self.params["netbios_ntype"] + obj["netbios_scope"] = str(self.params["netbios_scope"]) + self._get_ansible_param(obj, "netbios_scope") + self._get_ansible_param_bool(obj, "wins_server_enable") return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - self.pfsense.validate_string(params['name'], 'openvpn_override') + self.pfsense.validate_string(params["name"], "openvpn_override") - if params['state'] == 'absent': + if params["state"] == "absent": return True # check tunnel_networks - can be network alias or non-strict IP CIDR network - self.pfsense.validate_openvpn_tunnel_network(params.get('tunnel_network'), 'ipv4') - self.pfsense.validate_openvpn_tunnel_network(params.get('tunnel_network6'), 'ipv6') - - if params.get('local_network') and not self.pfsense.is_ipv4_network(params['local_network']): - self.module.fail_json(msg='A valid IPv4 network must be specified for local_network.') - if params.get('local_network6') and not self.pfsense.is_ipv6_network(params['local_networkv6']): - self.module.fail_json(msg='A valid IPv6 network must be specified for local_network6.') - if params.get('remote_network') and not self.pfsense.is_ipv4_network(params['remote_network']): - self.module.fail_json(msg='A valid IPv4 network must be specified for remote_network.') - if params.get('remote_network6') and not self.pfsense.is_ipv6_network(params['remote_networkv6']): - self.module.fail_json(msg='A valid IPv6 network must be specified for remote_network6.') + self.pfsense.validate_openvpn_tunnel_network( + params.get("tunnel_network"), "ipv4" + ) + self.pfsense.validate_openvpn_tunnel_network( + params.get("tunnel_network6"), "ipv6" + ) + + if params.get("local_network") and not self.pfsense.is_ipv4_network( + params["local_network"] + ): + self.module.fail_json( + msg="A valid IPv4 network must be specified for local_network." + ) + if params.get("local_network6") and not self.pfsense.is_ipv6_network( + params["local_networkv6"] + ): + self.module.fail_json( + msg="A valid IPv6 network must be specified for local_network6." + ) + if params.get("remote_network") and not self.pfsense.is_ipv4_network( + params["remote_network"] + ): + self.module.fail_json( + msg="A valid IPv4 network must be specified for remote_network." + ) + if params.get("remote_network6") and not self.pfsense.is_ipv6_network( + params["remote_networkv6"] + ): + self.module.fail_json( + msg="A valid IPv6 network must be specified for remote_network6." + ) ############################## # XML processing # - def _find_openvpn_csc(self, value, field='common_name'): - """ return openvpn-csc element """ + def _find_openvpn_csc(self, value, field="common_name"): + """return openvpn-csc element""" i = 0 for csc_elt in self.openvpn_csc_elt: field_elt = csc_elt.find(field) @@ -160,74 +206,80 @@ def _find_last_openvpn_idx(self): return i def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) if self._remove_deleted_params(): changed = True - self.diff['before'] = before + self.diff["before"] = before if changed: - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) - self.result['changed'] = True + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) + self.result["changed"] = True else: - self.diff['after'] = self.obj + self.diff["after"] = self.obj return (before, changed) def _create_target(self): - """ create the XML target_elt """ - target_elt = self.pfsense.new_element('openvpn-csc') - self.diff['before'] = '' - self.diff['after'] = self.obj - self.result['changed'] = True + """create the XML target_elt""" + target_elt = self.pfsense.new_element("openvpn-csc") + self.diff["before"] = "" + self.diff["after"] = self.obj + self.result["changed"] = True self.idx = self._find_last_openvpn_idx() return target_elt def _find_target(self): - """ find the XML target_elt """ - (target_elt, self.idx) = self._find_openvpn_csc(self.obj['common_name']) + """find the XML target_elt""" + (target_elt, self.idx) = self._find_openvpn_csc(self.obj["common_name"]) return target_elt def _get_params_to_remove(self): - """ returns the list of params to remove if they are not set """ + """returns the list of params to remove if they are not set""" params_to_remove = [] if self.pfsense.config_version >= 23.4: - params_to_remove.append('push_reset') + params_to_remove.append("push_reset") return params_to_remove def _remove_target_elt(self): - """ delete target_elt from xml """ + """delete target_elt from xml""" super(PFSenseOpenVPNOverrideModule, self)._remove_target_elt() - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) ############################## # run # def _remove(self): - """ delete obj """ - self.diff['after'] = '' - self.diff['before'] = '' + """delete obj""" + self.diff["after"] = "" + self.diff["before"] = "" super(PFSenseOpenVPNOverrideModule, self)._remove() - return self.pfsense.phpshell(OPENVPN_OVERRIDE_PHP_COMMAND_DEL.format(idx=self.idx)) + return self.pfsense.phpshell( + OPENVPN_OVERRIDE_PHP_COMMAND_DEL.format(idx=self.idx) + ) def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell(OPENVPN_OVERRIDE_PHP_COMMAND_SET.format(idx=self.idx)) + """make the target pfsense reload""" + return self.pfsense.phpshell( + OPENVPN_OVERRIDE_PHP_COMMAND_SET.format(idx=self.idx) + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'" + self.obj['common_name'] + "'" + """return obj's name""" + return "'" + self.obj["common_name"] + "'" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'common_name') - values += self.format_cli_field(self.obj, 'descr') + values += self.format_cli_field(self.obj, "common_name") + values += self.format_cli_field(self.obj, "descr") else: - values += self.format_updated_cli_field(self.obj, before, 'descr', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "descr", add_comma=(values) + ) return values diff --git a/plugins/module_utils/openvpn_server.py b/plugins/module_utils/openvpn_server.py index 20011efb..f30fe09b 100644 --- a/plugins/module_utils/openvpn_server.py +++ b/plugins/module_utils/openvpn_server.py @@ -5,79 +5,122 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type import base64 import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) OPENVPN_SERVER_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - mode=dict(type='str', choices=['p2p_tls', 'p2p_shared_key', 'server_tls', 'server_tls_user', 'server_user']), - authmode=dict(default=list(), required=False, type='list', elements='str'), - state=dict(default='present', choices=['present', 'absent']), - custom_options=dict(default=None, required=False, type='str'), - disable=dict(default=False, required=False, type='bool'), - interface=dict(default='wan', required=False, type='str'), - local_port=dict(default=1194, required=False, type='int'), - protocol=dict(default='UDP4', required=False, choices=['UDP4', 'TCP4']), - dev_mode=dict(default='tun', required=False, choices=['tun', 'tap']), - tls=dict(required=False, type='str'), - tls_type=dict(default='auth', required=False, choices=['auth', 'crypt']), - ca=dict(required=False, type='str'), - crl=dict(required=False, type='str'), - cert=dict(required=False, type='str'), - cert_depth=dict(default=1, required=False, type='int'), - strictusercn=dict(default=False, required=False, type='bool'), - remote_cert_tls=dict(default=False, required=False, type='bool'), - shared_key=dict(required=False, type='str', no_log=True), - dh_length=dict(default=2048, required=False, type='int'), - ecdh_curve=dict(default='none', required=False, choices=['none', 'prime256v1', 'secp384r1', 'secp521r1']), - ncp_enable=dict(default=True, required=False, type='bool'), + name=dict(required=True, type="str"), + mode=dict( + type="str", + choices=[ + "p2p_tls", + "p2p_shared_key", + "server_tls", + "server_tls_user", + "server_user", + ], + ), + authmode=dict(default=list(), required=False, type="list", elements="str"), + state=dict(default="present", choices=["present", "absent"]), + custom_options=dict(default=None, required=False, type="str"), + disable=dict(default=False, required=False, type="bool"), + interface=dict(default="wan", required=False, type="str"), + local_port=dict(default=1194, required=False, type="int"), + protocol=dict(default="UDP4", required=False, choices=["UDP4", "TCP4"]), + dev_mode=dict(default="tun", required=False, choices=["tun", "tap"]), + tls=dict(required=False, type="str"), + tls_type=dict(default="auth", required=False, choices=["auth", "crypt"]), + ca=dict(required=False, type="str"), + crl=dict(required=False, type="str"), + cert=dict(required=False, type="str"), + cert_depth=dict(default=1, required=False, type="int"), + strictusercn=dict(default=False, required=False, type="bool"), + remote_cert_tls=dict(default=False, required=False, type="bool"), + shared_key=dict(required=False, type="str", no_log=True), + dh_length=dict(default=2048, required=False, type="int"), + ecdh_curve=dict( + default="none", + required=False, + choices=["none", "prime256v1", "secp384r1", "secp521r1"], + ), + ncp_enable=dict(default=True, required=False, type="bool"), # ncp_ciphers=dict(default=list('AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'), required=False, # choices=['AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'], type='list', elements='str'), - data_ciphers=dict(default=['AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'], required=False, - choices=['AES-256-CBC', 'AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305'], type='list', elements='str'), - data_ciphers_fallback=dict(default='AES-256-CBC', required=False, choices=['AES-256-CBC', 'AES-256-GCM', 'AES-128-GCM', 'CHACHA20-POLY1305']), - digest=dict(default='SHA256', required=False, type='str'), - tunnel_network=dict(default='', required=False, type='str'), - tunnel_networkv6=dict(default='', required=False, type='str'), - local_network=dict(default='', required=False, type='str'), - local_networkv6=dict(default='', required=False, type='str'), - remote_network=dict(default='', required=False, type='str'), - remote_networkv6=dict(default='', required=False, type='str'), - gwredir=dict(default=False, required=False, type='bool'), - gwredir6=dict(default=False, required=False, type='bool'), - maxclients=dict(default=None, required=False, type='int'), - allow_compression=dict(default='no', required=False, choices=['no', 'asym', 'yes']), - compression=dict(default='', required=False, choices=['', 'none', 'stub', 'stub-v2', 'lz4', 'lz4-v2', 'lzo', 'noadapt', 'adaptive', 'yes', 'no']), - compression_push=dict(default=False, required=False, type='bool'), - passtos=dict(default=False, required=False, type='bool'), - client2client=dict(default=False, required=False, type='bool'), - dynamic_ip=dict(default=False, required=False, type='bool'), - topology=dict(default='subnet', required=False, choices=['net30', 'subnet']), - inactive_seconds=dict(default=0, required=False, type='int'), - keepalive_interval=dict(default=10, required=False, type='int'), - keepalive_timeout=dict(default=60, required=False, type='int'), - exit_notify=dict(default='none', required=False, choices=['none', '1', '2']), - dns_domain=dict(default='', required=False, type='str'), - dns_server1=dict(default='', required=False, type='str'), - dns_server2=dict(default='', required=False, type='str'), - dns_server3=dict(default='', required=False, type='str'), - dns_server4=dict(default='', required=False, type='str'), - push_register_dns=dict(default=False, required=False, type='bool'), - username_as_common_name=dict(default=False, required=False, type='bool'), - create_gw=dict(default='both', required=False, type='str', choices=['both', 'v4only', 'v6only']), - verbosity_level=dict(default=1, required=False, type='int'), + data_ciphers=dict( + default=["AES-256-GCM", "AES-128-GCM", "CHACHA20-POLY1305"], + required=False, + choices=["AES-256-CBC", "AES-256-GCM", "AES-128-GCM", "CHACHA20-POLY1305"], + type="list", + elements="str", + ), + data_ciphers_fallback=dict( + default="AES-256-CBC", + required=False, + choices=["AES-256-CBC", "AES-256-GCM", "AES-128-GCM", "CHACHA20-POLY1305"], + ), + digest=dict(default="SHA256", required=False, type="str"), + tunnel_network=dict(default="", required=False, type="str"), + tunnel_networkv6=dict(default="", required=False, type="str"), + local_network=dict(default="", required=False, type="str"), + local_networkv6=dict(default="", required=False, type="str"), + remote_network=dict(default="", required=False, type="str"), + remote_networkv6=dict(default="", required=False, type="str"), + gwredir=dict(default=False, required=False, type="bool"), + gwredir6=dict(default=False, required=False, type="bool"), + maxclients=dict(default=None, required=False, type="int"), + allow_compression=dict(default="no", required=False, choices=["no", "asym", "yes"]), + compression=dict( + default="", + required=False, + choices=[ + "", + "none", + "stub", + "stub-v2", + "lz4", + "lz4-v2", + "lzo", + "noadapt", + "adaptive", + "yes", + "no", + ], + ), + compression_push=dict(default=False, required=False, type="bool"), + passtos=dict(default=False, required=False, type="bool"), + client2client=dict(default=False, required=False, type="bool"), + dynamic_ip=dict(default=False, required=False, type="bool"), + topology=dict(default="subnet", required=False, choices=["net30", "subnet"]), + inactive_seconds=dict(default=0, required=False, type="int"), + keepalive_interval=dict(default=10, required=False, type="int"), + keepalive_timeout=dict(default=60, required=False, type="int"), + exit_notify=dict(default="none", required=False, choices=["none", "1", "2"]), + dns_domain=dict(default="", required=False, type="str"), + dns_server1=dict(default="", required=False, type="str"), + dns_server2=dict(default="", required=False, type="str"), + dns_server3=dict(default="", required=False, type="str"), + dns_server4=dict(default="", required=False, type="str"), + push_register_dns=dict(default=False, required=False, type="bool"), + username_as_common_name=dict(default=False, required=False, type="bool"), + create_gw=dict( + default="both", required=False, type="str", choices=["both", "v4only", "v6only"] + ), + verbosity_level=dict(default=1, required=False, type="int"), ) OPENVPN_SERVER_REQUIRED_IF = [ - ['state', 'present', ['mode']], - ['mode', 'p2p_tls', ['ca']], - ['mode', 'server_tls', ['ca']], - ['mode', 'server_tls_user', ['ca']], - ['mode', 'p2p_shared_key', ['shared_key']], + ["state", "present", ["mode"]], + ["mode", "p2p_tls", ["ca"]], + ["mode", "server_tls", ["ca"]], + ["mode", "server_tls_user", ["ca"]], + ["mode", "p2p_shared_key", ["shared_key"]], ] OPENVPN_SERVER_PHP_COMMAND_PREFIX = """ @@ -85,22 +128,28 @@ $ovpn = config_get_path('openvpn/openvpn-server')[{idx}]; """ -OPENVPN_SERVER_PHP_COMMAND_SET = OPENVPN_SERVER_PHP_COMMAND_PREFIX + """ +OPENVPN_SERVER_PHP_COMMAND_SET = ( + OPENVPN_SERVER_PHP_COMMAND_PREFIX + + """ openvpn_resync('server',$ovpn); openvpn_resync_csc_all(); """ +) -OPENVPN_SERVER_PHP_COMMAND_DEL = OPENVPN_SERVER_PHP_COMMAND_PREFIX + """ +OPENVPN_SERVER_PHP_COMMAND_DEL = ( + OPENVPN_SERVER_PHP_COMMAND_PREFIX + + """ openvpn_delete('server',$ovpn); """ +) class PFSenseOpenVPNServerModule(PFSenseModuleBase): - """ module managing pfSense OpenVPN configuration """ + """module managing pfSense OpenVPN configuration""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return OPENVPN_SERVER_ARGUMENT_SPEC ############################## @@ -109,12 +158,14 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseOpenVPNServerModule, self).__init__(module, pfsense) self.name = "pfsense_openvpn_server" - self.root_elt = self.pfsense.get_element('openvpn', create_node=True) + self.root_elt = self.pfsense.get_element("openvpn", create_node=True) self.obj = dict() - cmd = ('require_once("openvpn.inc");;' - '$digestlist = openvpn_get_digestlist();' - 'echo json_encode($digestlist);') + cmd = ( + 'require_once("openvpn.inc");;' + "$digestlist = openvpn_get_digestlist();" + "echo json_encode($digestlist);" + ) self.digestlist = self.pfsense.php(cmd) ############################## @@ -127,161 +178,228 @@ def _get_digest_name(self, digest: str): self.module.fail_json(msg=f"Invalid digest '{digest}'") def _params_to_obj(self): - """ return dict from module params """ + """return dict from module params""" obj = dict() - obj['description'] = self.params['name'] - if self.params['state'] == 'present': - obj['custom_options'] = self.params['custom_options'] - self._get_ansible_param_bool(obj, 'disable') - self._get_ansible_param_bool(obj, 'strictusercn') - self._get_ansible_param_bool(obj, 'remote_cert_tls') - obj['mode'] = self.params['mode'] - obj['dev_mode'] = self.params['dev_mode'] - obj['interface'] = self.params['interface'] - obj['protocol'] = self.params['protocol'] - obj['local_port'] = str(self.params['local_port']) - self._get_ansible_param(obj, 'maxclients') - obj['verbosity_level'] = str(self.params['verbosity_level']) - obj['data_ciphers_fallback'] = self.params['data_ciphers_fallback'] - obj['data_ciphers'] = ",".join(self.params['data_ciphers']) - self._get_ansible_param_bool(obj, 'ncp_enable', force=True, value='enabled', value_false='disabled') - self._get_ansible_param_bool(obj, 'gwredir', force=True, value='yes') - self._get_ansible_param_bool(obj, 'gwredir6', force=True, value='yes') - self._get_ansible_param_bool(obj, 'compression_push', force=True, value='yes', value_false='') - self._get_ansible_param_bool(obj, 'passtos', force=True, value='yes', value_false='') - self._get_ansible_param_bool(obj, 'client2client', force=True, value='yes', value_false='') - self._get_ansible_param_bool(obj, 'dynamic_ip', force=True, value='yes', value_false='') - self._get_ansible_param_bool(obj, 'push_register_dns') - self._get_ansible_param_bool(obj, 'username_as_common_name', force=True, value='enabled', value_false='disabled') - obj['digest'] = self._get_digest_name(self.params['digest']) - obj['tunnel_network'] = self.params['tunnel_network'] - obj['tunnel_networkv6'] = self.params['tunnel_networkv6'] - obj['local_network'] = self.params['local_network'] - obj['local_networkv6'] = self.params['local_networkv6'] - obj['remote_network'] = self.params['remote_network'] - obj['remote_networkv6'] = self.params['remote_networkv6'] - obj['allow_compression'] = self.params['allow_compression'] - obj['compression'] = self.params['compression'] - obj['topology'] = self.params['topology'] - self._get_ansible_param(obj, 'inactive_seconds') - self._get_ansible_param(obj, 'keepalive_interval') - self._get_ansible_param(obj, 'keepalive_timeout') - obj['exit_notify'] = self.params['exit_notify'] - obj['create_gw'] = self.params['create_gw'] - - if 'user' in self.params['mode']: - obj['authmode'] = ",".join(self.params['authmode']) - - if 'tls' in self.params['mode']: + obj["description"] = self.params["name"] + if self.params["state"] == "present": + obj["custom_options"] = self.params["custom_options"] + self._get_ansible_param_bool(obj, "disable") + self._get_ansible_param_bool(obj, "strictusercn") + self._get_ansible_param_bool(obj, "remote_cert_tls") + obj["mode"] = self.params["mode"] + obj["dev_mode"] = self.params["dev_mode"] + obj["interface"] = self.params["interface"] + obj["protocol"] = self.params["protocol"] + obj["local_port"] = str(self.params["local_port"]) + self._get_ansible_param(obj, "maxclients") + obj["verbosity_level"] = str(self.params["verbosity_level"]) + obj["data_ciphers_fallback"] = self.params["data_ciphers_fallback"] + obj["data_ciphers"] = ",".join(self.params["data_ciphers"]) + self._get_ansible_param_bool( + obj, "ncp_enable", force=True, value="enabled", value_false="disabled" + ) + self._get_ansible_param_bool(obj, "gwredir", force=True, value="yes") + self._get_ansible_param_bool(obj, "gwredir6", force=True, value="yes") + self._get_ansible_param_bool( + obj, "compression_push", force=True, value="yes", value_false="" + ) + self._get_ansible_param_bool( + obj, "passtos", force=True, value="yes", value_false="" + ) + self._get_ansible_param_bool( + obj, "client2client", force=True, value="yes", value_false="" + ) + self._get_ansible_param_bool( + obj, "dynamic_ip", force=True, value="yes", value_false="" + ) + self._get_ansible_param_bool(obj, "push_register_dns") + self._get_ansible_param_bool( + obj, + "username_as_common_name", + force=True, + value="enabled", + value_false="disabled", + ) + obj["digest"] = self._get_digest_name(self.params["digest"]) + obj["tunnel_network"] = self.params["tunnel_network"] + obj["tunnel_networkv6"] = self.params["tunnel_networkv6"] + obj["local_network"] = self.params["local_network"] + obj["local_networkv6"] = self.params["local_networkv6"] + obj["remote_network"] = self.params["remote_network"] + obj["remote_networkv6"] = self.params["remote_networkv6"] + obj["allow_compression"] = self.params["allow_compression"] + obj["compression"] = self.params["compression"] + obj["topology"] = self.params["topology"] + self._get_ansible_param(obj, "inactive_seconds") + self._get_ansible_param(obj, "keepalive_interval") + self._get_ansible_param(obj, "keepalive_timeout") + obj["exit_notify"] = self.params["exit_notify"] + obj["create_gw"] = self.params["create_gw"] + + if "user" in self.params["mode"]: + obj["authmode"] = ",".join(self.params["authmode"]) + + if "tls" in self.params["mode"]: # Find the caref id for the named CA if self.params is not None: - ca_elt = self.pfsense.find_ca_elt(self.params['ca']) + ca_elt = self.pfsense.find_ca_elt(self.params["ca"]) if ca_elt is None: - self.module.fail_json(msg='{0} is not a valid certificate authority'.format(self.params['ca'])) - obj['caref'] = ca_elt.find('refid').text + self.module.fail_json( + msg="{0} is not a valid certificate authority".format( + self.params["ca"] + ) + ) + obj["caref"] = ca_elt.find("refid").text # Find the crlref id for the named CRL if any - if self.params['crl'] is not None: - crl_elt = self.pfsense.find_crl_elt(self.params['crl']) + if self.params["crl"] is not None: + crl_elt = self.pfsense.find_crl_elt(self.params["crl"]) if crl_elt is None: - self.module.fail_json(msg='{0} is not a valid certificate revocation list'.format(self.params['crl'])) - obj['crlref'] = crl_elt.find('refid').text + self.module.fail_json( + msg="{0} is not a valid certificate revocation list".format( + self.params["crl"] + ) + ) + obj["crlref"] = crl_elt.find("refid").text else: - obj['crlref'] = '' + obj["crlref"] = "" # Find the certref id for the named certificate if any - if self.params['cert'] is not None: - cert_elt = self.pfsense.find_cert_elt(self.params['cert']) + if self.params["cert"] is not None: + cert_elt = self.pfsense.find_cert_elt(self.params["cert"]) if cert_elt is None: - self.module.fail_json(msg='{0} is not a valid certificate'.format(self.params['cert'])) - obj['certref'] = cert_elt.find('refid').text - - obj['cert_depth'] = str(self.params['cert_depth']) - obj['dh_length'] = str(self.params['dh_length']) - obj['ecdh_curve'] = self.params['ecdh_curve'] - self._get_ansible_param(obj, 'tls') - - if self.params['tls'] is not None: - obj['tls'] = self.params['tls'] - obj['tls_type'] = self.params['tls_type'] - - if 'server' in self.params['mode']: - obj['dns_domain'] = self.params['dns_domain'] - obj['dns_server1'] = self.params['dns_server1'] - obj['dns_server2'] = self.params['dns_server2'] - obj['dns_server3'] = self.params['dns_server3'] - obj['dns_server4'] = self.params['dns_server4'] - - if self.params['mode'] == 'p2p_shared_key': - obj['shared_key'] = self.params['shared_key'] + self.module.fail_json( + msg="{0} is not a valid certificate".format( + self.params["cert"] + ) + ) + obj["certref"] = cert_elt.find("refid").text + + obj["cert_depth"] = str(self.params["cert_depth"]) + obj["dh_length"] = str(self.params["dh_length"]) + obj["ecdh_curve"] = self.params["ecdh_curve"] + self._get_ansible_param(obj, "tls") + + if self.params["tls"] is not None: + obj["tls"] = self.params["tls"] + obj["tls_type"] = self.params["tls_type"] + + if "server" in self.params["mode"]: + obj["dns_domain"] = self.params["dns_domain"] + obj["dns_server1"] = self.params["dns_server1"] + obj["dns_server2"] = self.params["dns_server2"] + obj["dns_server3"] = self.params["dns_server3"] + obj["dns_server4"] = self.params["dns_server4"] + + if self.params["mode"] == "p2p_shared_key": + obj["shared_key"] = self.params["shared_key"] return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check name - self.pfsense.validate_string(params['name'], 'openvpn') + self.pfsense.validate_string(params["name"], "openvpn") - if params['state'] == 'absent': + if params["state"] == "absent": return True # tls is not valid for p2p_shared_key - if params['mode'] == 'p2p_shared_key' and params['tls'] is not None: - self.module.fail_json(msg='tls parameter is not valied with p2p_shared_key mode.') + if params["mode"] == "p2p_shared_key" and params["tls"] is not None: + self.module.fail_json( + msg="tls parameter is not valied with p2p_shared_key mode." + ) # check tunnel_networks - can be network alias or non-strict IP CIDR network - self.pfsense.validate_openvpn_tunnel_network(params.get('tunnel_network'), 'ipv4') - self.pfsense.validate_openvpn_tunnel_network(params.get('tunnel_network6'), 'ipv6') + self.pfsense.validate_openvpn_tunnel_network( + params.get("tunnel_network"), "ipv4" + ) + self.pfsense.validate_openvpn_tunnel_network( + params.get("tunnel_network6"), "ipv6" + ) # Check auth servers - if len(params['authmode']) > 0: - system = self.pfsense.get_element('system') - for authsrv in params['authmode']: - if authsrv != 'Local Database' and len(system.findall("authserver[name='{0}']".format(authsrv))) == 0: - self.module.fail_json(msg='Cannot find authentication server {0}.'.format(authsrv)) + if len(params["authmode"]) > 0: + system = self.pfsense.get_element("system") + for authsrv in params["authmode"]: + if ( + authsrv != "Local Database" + and len(system.findall("authserver[name='{0}']".format(authsrv))) + == 0 + ): + self.module.fail_json( + msg="Cannot find authentication server {0}.".format(authsrv) + ) # validate key - for param in ['shared_key', 'tls']: + for param in ["shared_key", "tls"]: if params[param] is not None: key = params[param] - if key == 'generate': + if key == "generate": # generate during _find_target (after _params_to_obj) - for just generate if not exists pass - elif re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL): + elif re.search( + "^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$", + key, + flags=re.MULTILINE | re.DOTALL, + ): params[param] = base64.b64encode(key.encode()).decode() else: key_decoded = base64.b64decode(key.encode()).decode() - if not re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', - key_decoded, flags=re.MULTILINE | re.DOTALL): - self.module.fail_json(msg='Could not recognize {0} key format: {1}'.format(param, key_decoded)) + if not re.search( + "^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$", + key_decoded, + flags=re.MULTILINE | re.DOTALL, + ): + self.module.fail_json( + msg="Could not recognize {0} key format: {1}".format( + param, key_decoded + ) + ) def _openvpn_port_used(self, protocol, interface, port, vpnid=0): - for elt in self.root_elt.findall('*[local_port]'): - if (elt.find('disable')): + for elt in self.root_elt.findall("*[local_port]"): + if elt.find("disable"): continue - this_vpnid = int(elt.find('vpnid').text) - if (this_vpnid == int(vpnid)): + this_vpnid = int(elt.find("vpnid").text) + if this_vpnid == int(vpnid): continue - this_interface = elt.find('interface').text - this_protocol = elt.find('protocol').text + this_interface = elt.find("interface").text + this_protocol = elt.find("protocol").text # (TCP|UDP)(4|6) does not conflict unless interface is any - if ((this_interface != "any" and interface != "any") and (len(protocol) == 4) and - (len(this_protocol) == 4) and (this_protocol[0:3] == protocol[0:3]) and (this_protocol[3] != protocol[3])): + if ( + (this_interface != "any" and interface != "any") + and (len(protocol) == 4) + and (len(this_protocol) == 4) + and (this_protocol[0:3] == protocol[0:3]) + and (this_protocol[3] != protocol[3]) + ): continue - this_port_text = elt.find('local_port').text + this_port_text = elt.find("local_port").text if this_port_text is None: continue this_port = int(this_port_text) - if (this_port == port and (this_protocol[0:3] == protocol[0:3]) and - (this_interface == interface or this_interface == "any" or interface == "any")): - self.module.fail_json(msg='The specified local_port ({0}) is in use by vpn ID {1}'.format(port, this_vpnid)) + if ( + this_port == port + and (this_protocol[0:3] == protocol[0:3]) + and ( + this_interface == interface + or this_interface == "any" + or interface == "any" + ) + ): + self.module.fail_json( + msg="The specified local_port ({0}) is in use by vpn ID {1}".format( + port, this_vpnid + ) + ) def _nextvpnid(self): - """ find next available vpnid """ + """find next available vpnid""" vpnid = 1 while len(self.root_elt.findall("*[vpnid='{0}']".format(vpnid))) != 0: vpnid += 1 @@ -290,10 +408,10 @@ def _nextvpnid(self): ############################## # XML processing # - def _find_openvpn_server(self, value, field='description'): - """ return openvpn-server element """ + def _find_openvpn_server(self, value, field="description"): + """return openvpn-server element""" i = 0 - for elt in self.root_elt.findall('openvpn-server'): + for elt in self.root_elt.findall("openvpn-server"): field_elt = elt.find(field) if field_elt is not None and field_elt.text == value: return (elt, i) @@ -302,55 +420,75 @@ def _find_openvpn_server(self, value, field='description'): def _find_last_openvpn_idx(self): i = 0 - for elt in self.root_elt.findall('openvpn-server'): + for elt in self.root_elt.findall("openvpn-server"): i += 1 return i def _get_params_to_remove(self): - """ returns the list of params to remove if they are not set """ + """returns the list of params to remove if they are not set""" params_to_remove = [] - for param in ['disable', 'strictusercn', 'push_register_dns', 'remote_cert_tls']: + for param in [ + "disable", + "strictusercn", + "push_register_dns", + "remote_cert_tls", + ]: if not self.params[param]: params_to_remove.append(param) return params_to_remove def _copy_and_update_target(self): - """ update the XML target_elt """ - (before, changed) = super(PFSenseOpenVPNServerModule, self)._copy_and_update_target() + """update the XML target_elt""" + (before, changed) = super( + PFSenseOpenVPNServerModule, self + )._copy_and_update_target() # Check if local port is used - self._openvpn_port_used(self.params['protocol'], self.params['interface'], self.params['local_port'], before['vpnid']) + self._openvpn_port_used( + self.params["protocol"], + self.params["interface"], + self.params["local_port"], + before["vpnid"], + ) if not changed: - self.diff['after'] = self.obj + self.diff["after"] = self.obj - self.result['vpnid'] = int(before['vpnid']) + self.result["vpnid"] = int(before["vpnid"]) return (before, changed) def _create_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" # Check if local port is used - self._openvpn_port_used(self.params['protocol'], self.params['interface'], self.params['local_port']) - target_elt = self.pfsense.new_element('openvpn-server') - self.obj['vpnid'] = self._nextvpnid() - self.result['vpnid'] = int(self.obj['vpnid']) - self.diff['before'] = '' - self.diff['after'] = self.obj - self.result['changed'] = True + self._openvpn_port_used( + self.params["protocol"], self.params["interface"], self.params["local_port"] + ) + target_elt = self.pfsense.new_element("openvpn-server") + self.obj["vpnid"] = self._nextvpnid() + self.result["vpnid"] = int(self.obj["vpnid"]) + self.diff["before"] = "" + self.diff["after"] = self.obj + self.result["changed"] = True self.idx = self._find_last_openvpn_idx() return target_elt def _find_target(self): - """ find the XML target_elt """ - (target_elt, self.idx) = self._find_openvpn_server(self.obj['description']) - for param in ['shared_key', 'tls']: + """find the XML target_elt""" + (target_elt, self.idx) = self._find_openvpn_server(self.obj["description"]) + for param in ["shared_key", "tls"]: current_elt = self.pfsense.get_element(param, target_elt) - if self.params[param] == 'generate': + if self.params[param] == "generate": if current_elt is None: - (dummy, key, stderr) = self.module.run_command('/usr/local/sbin/openvpn --genkey secret /dev/stdout') + (dummy, key, stderr) = self.module.run_command( + "/usr/local/sbin/openvpn --genkey secret /dev/stdout" + ) if stderr != "": - self.module.fail_json(msg='generate for "{0}" secret key: {1}'.format(param, stderr)) + self.module.fail_json( + msg='generate for "{0}" secret key: {1}'.format( + param, stderr + ) + ) self.obj[param] = base64.b64encode(key.encode()).decode() self.result[param] = self.obj[param] else: @@ -361,20 +499,34 @@ def _find_target(self): # run # def _pre_remove_target_elt(self): - """ processing before removing elt """ - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) - - if len(self.pfsense.interfaces.findall("*[if='ovpns{0}']".format(self.diff['before']['vpnid']))) > 0: - self.module.fail_json(msg='Cannot delete the OpenVPN instance while the interface ovpns{0} is assigned. Remove the interface assignment first.' - .format(self.diff['before']['vpnid'])) - - self.result['vpnid'] = int(self.diff['before']['vpnid']) - self.command_output = self.pfsense.phpshell(OPENVPN_SERVER_PHP_COMMAND_DEL.format(idx=self.idx)) + """processing before removing elt""" + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) + + if ( + len( + self.pfsense.interfaces.findall( + "*[if='ovpns{0}']".format(self.diff["before"]["vpnid"]) + ) + ) + > 0 + ): + self.module.fail_json( + msg="Cannot delete the OpenVPN instance while the interface ovpns{0} is assigned. Remove the interface assignment first.".format( + self.diff["before"]["vpnid"] + ) + ) + + self.result["vpnid"] = int(self.diff["before"]["vpnid"]) + self.command_output = self.pfsense.phpshell( + OPENVPN_SERVER_PHP_COMMAND_DEL.format(idx=self.idx) + ) def _update(self): - """ make the target pfsense reload """ - if self.params['state'] == 'present': - return self.pfsense.phpshell(OPENVPN_SERVER_PHP_COMMAND_SET.format(idx=self.idx)) + """make the target pfsense reload""" + if self.params["state"] == "present": + return self.pfsense.phpshell( + OPENVPN_SERVER_PHP_COMMAND_SET.format(idx=self.idx) + ) else: return self.command_output @@ -382,14 +534,16 @@ def _update(self): # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'" + self.obj['description'] + "'" + """return obj's name""" + return "'" + self.obj["description"] + "'" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'description') + values += self.format_cli_field(self.obj, "description") else: - values += self.format_updated_cli_field(self.obj, before, 'description', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "description", add_comma=(values) + ) return values diff --git a/plugins/module_utils/pfsense.py b/plugins/module_utils/pfsense.py index 339d7f4f..1132ba41 100644 --- a/plugins/module_utils/pfsense.py +++ b/plugins/module_utils/pfsense.py @@ -3,10 +3,12 @@ # Copyright: (c) 2018, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import sys + if sys.version_info >= (3, 4): import html import json @@ -24,13 +26,13 @@ def xml_find(node, elt): res = node.find(elt) if res is None: - res = ET.Element('') - res.text = '' + res = ET.Element("") + res.text = "" return res class PFSenseModule(object): - """ class managing pfsense base configuration """ + """class managing pfsense base configuration""" from ansible_collections.pfsensible.core.plugins.module_utils.__impl.interfaces import ( get_interface_display_name, @@ -63,24 +65,24 @@ class PFSenseModule(object): validate_openvpn_tunnel_network, ) - def __init__(self, module, config='/cf/conf/config.xml'): + def __init__(self, module, config="/cf/conf/config.xml"): self.module = module self.config = config self.tree = ET.parse(config) self.root = self.tree.getroot() - self.config_version = float(self.get_element('version').text) - self.aliases = self.get_element('aliases', create_node=True) - self.interfaces = self.get_element('interfaces') - self.ifgroups = self.get_element('ifgroups') - self.rules = self.get_element('filter') - self.shapers = self.get_element('shaper') - self.dnshapers = self.get_element('dnshaper') - self.vlans = self.get_element('vlans') - self.gateways = self.get_element('gateways') - self.ipsec = self.get_element('ipsec') - self.openvpn = self.get_element('openvpn') - self.virtualip = self.get_element('virtualip') - self.debug = open('/tmp/pfsense.debug', 'w') + self.config_version = float(self.get_element("version").text) + self.aliases = self.get_element("aliases", create_node=True) + self.interfaces = self.get_element("interfaces") + self.ifgroups = self.get_element("ifgroups") + self.rules = self.get_element("filter") + self.shapers = self.get_element("shaper") + self.dnshapers = self.get_element("dnshaper") + self.vlans = self.get_element("vlans") + self.gateways = self.get_element("gateways") + self.ipsec = self.get_element("ipsec") + self.openvpn = self.get_element("openvpn") + self.virtualip = self.get_element("virtualip") + self.debug = open("/tmp/pfsense.debug", "w") if sys.version_info >= (3, 4): self._scrub() @@ -95,34 +97,34 @@ def _scrub(self): @staticmethod def addr_normalize(addr): - """ return address element formatted like module argument """ - address = '' - ports = '' - if 'address' in addr: - address = addr['address'] - if 'any' in addr: - address = 'any' - if 'network' in addr: - address = 'NET:%s' % addr['network'] - if address == '': - raise ValueError('UNKNOWN addr %s' % addr) - if 'port' in addr: - ports = addr['port'] - if 'not' in addr: - address = '!' + address + """return address element formatted like module argument""" + address = "" + ports = "" + if "address" in addr: + address = addr["address"] + if "any" in addr: + address = "any" + if "network" in addr: + address = "NET:%s" % addr["network"] + if address == "": + raise ValueError("UNKNOWN addr %s" % addr) + if "port" in addr: + ports = addr["port"] + if "not" in addr: + address = "!" + address return address, ports @staticmethod - def new_element(tag, text='\n\t\t\t'): - """ Create and return new XML configuration element """ + def new_element(tag, text="\n\t\t\t"): + """Create and return new XML configuration element""" elt = ET.Element(tag) # Attempt to preserve some of the formatting of pfSense's config.xml elt.text = text - elt.tail = '\n\t\t' + elt.tail = "\n\t\t" return elt def get_element(self, node, root_elt=None, create_node=False): - """ return configuration element """ + """return configuration element""" if root_elt is None: root_elt = self.root elt = root_elt.find(node) @@ -132,24 +134,26 @@ def get_element(self, node, root_elt=None, create_node=False): return elt def get_elements(self, node, root_elt=None): - """ return all configuration elements """ + """return all configuration elements""" if root_elt is None: root_elt = self.root return root_elt.findall(node) def get_index(self, elt, root_elt=None): - """ Get elt index """ + """Get elt index""" if root_elt is None: root_elt = self.root return list(root_elt).index(elt) - def find_elt(self, node, search_text, search_field='descr', root_elt=None, multiple_ok=False): - """ return object elt if found """ + def find_elt( + self, node, search_text, search_field="descr", root_elt=None, multiple_ok=False + ): + """return object elt if found""" search_xpath = "{0}[{1}='{2}']".format(node, search_field, search_text) return self.find_elt_xpath(search_xpath, root_elt, multiple_ok) def find_elt_xpath(self, search_xpath, root_elt=None, multiple_ok=False): - """ return object elt if found """ + """return object elt if found""" if root_elt is None: root_elt = self.root result = root_elt.findall(search_xpath) @@ -160,12 +164,14 @@ def find_elt_xpath(self, search_xpath, root_elt=None, multiple_ok=False): if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple elements for name {0}.'.format(self.obj['name'])) + self.module.fail_json( + msg="Found multiple elements for name {0}.".format(self.obj["name"]) + ) return None @staticmethod def remove_deleted_param_from_elt(elt, param, params): - """ Remove from a deleted param from an xml elt """ + """Remove from a deleted param from an xml elt""" changed = False if param not in params: param_elt = elt.find(param) @@ -175,29 +181,29 @@ def remove_deleted_param_from_elt(elt, param, params): return changed def is_ipsec_enabled(self): - """ return True if ipsec is enabled """ + """return True if ipsec is enabled""" if self.ipsec is None: return False for elt in self.ipsec: - if elt.tag == 'phase1' and elt.find('disabled') is None: + if elt.tag == "phase1" and elt.find("disabled") is None: return True return False def is_openvpn_enabled(self): - """ return True if openvpn is enabled """ + """return True if openvpn is enabled""" if self.openvpn is None: return False for elt in self.openvpn: - if elt.tag == 'openvpn-server' or elt.tag == 'openvpn-client': + if elt.tag == "openvpn-server" or elt.tag == "openvpn-client": return True return False - def find_ipsec_phase1(self, field_value, field='descr'): - """ return ipsec phase1 elt if found """ + def find_ipsec_phase1(self, field_value, field="descr"): + """return ipsec phase1 elt if found""" for ipsec_elt in self.ipsec: - if ipsec_elt.tag != 'phase1': + if ipsec_elt.tag != "phase1": continue field_elt = ipsec_elt.find(field) @@ -208,19 +214,22 @@ def find_ipsec_phase1(self, field_value, field='descr'): @staticmethod def rule_match_interface(rule_elt, interface, floating): - """ check if a rule elt match the targeted interface - floating rules must match the floating mode instead of the interface name + """check if a rule elt match the targeted interface + floating rules must match the floating mode instead of the interface name """ - interface_elt = rule_elt.find('interface') - floating_elt = rule_elt.find('floating') + interface_elt = rule_elt.find("interface") + floating_elt = rule_elt.find("floating") if floating_elt is not None: return floating elif floating: return False - return interface_elt is not None and interface_elt.text.lower() == interface.lower() + return ( + interface_elt is not None + and interface_elt.text.lower() == interface.lower() + ) def get_interface_rules_count(self, interface, floating): - """ get rules count in interface/floating """ + """get rules count in interface/floating""" count = 0 for rule_elt in self.rules: if not self.rule_match_interface(rule_elt, interface, floating): @@ -230,13 +239,13 @@ def get_interface_rules_count(self, interface, floating): return count def get_rule_position(self, descr, interface, floating, first=True): - """ get rule position in interface/floating """ + """get rule position in interface/floating""" i = 0 found = None for rule_elt in self.rules: if not self.rule_match_interface(rule_elt, interface, floating): continue - descr_elt = rule_elt.find('descr') + descr_elt = rule_elt.find("descr") if descr_elt is not None and descr_elt.text == descr: if first: return i @@ -247,34 +256,41 @@ def get_rule_position(self, descr, interface, floating, first=True): return found def copy_dict_to_element(self, src, top_elt, sub=0, prev_elt=None): - """ Copy/update top_elt from src """ + """Copy/update top_elt from src""" changed = False - for (key, value) in src.items(): + for key, value in src.items(): this_elt = top_elt.find(key) - self.debug.write('changed=%s key=%s value=%s this_elt=%s, sub=%d\n' % (changed, key, value, this_elt, sub)) + self.debug.write( + "changed=%s key=%s value=%s this_elt=%s, sub=%d\n" + % (changed, key, value, this_elt, sub) + ) if this_elt is None: if isinstance(value, dict): changed = True - self.debug.write('calling copy_dict_to_element()\n') + self.debug.write("calling copy_dict_to_element()\n") # Create a new element new_elt = ET.Element(key) - new_elt.text = '\n%s' % ('\t' * (sub + 4)) - new_elt.tail = '\n%s' % ('\t' * (sub + 2)) + new_elt.text = "\n%s" % ("\t" * (sub + 4)) + new_elt.tail = "\n%s" % ("\t" * (sub + 2)) if prev_elt is not None: - prev_elt.tail = '\n%s' % ('\t' * (sub + 2)) + prev_elt.tail = "\n%s" % ("\t" * (sub + 2)) prev_elt = new_elt - self.copy_dict_to_element(value, new_elt, sub=sub + 1, prev_elt=prev_elt) + self.copy_dict_to_element( + value, new_elt, sub=sub + 1, prev_elt=prev_elt + ) top_elt.append(new_elt) elif isinstance(value, list): if value: changed = True if prev_elt is not None: - prev_elt.tail = '\n%s' % ('\t' * (sub + 2)) + prev_elt.tail = "\n%s" % ("\t" * (sub + 2)) for item in value: new_elt = self.new_element(key) prev_elt = new_elt if isinstance(item, dict): - self.copy_dict_to_element(item, new_elt, sub=sub + 1, prev_elt=prev_elt) + self.copy_dict_to_element( + item, new_elt, sub=sub + 1, prev_elt=prev_elt + ) else: new_elt.text = item top_elt.append(new_elt) @@ -283,16 +299,21 @@ def copy_dict_to_element(self, src, top_elt, sub=0, prev_elt=None): # Create a new element new_elt = ET.Element(key) new_elt.text = value - new_elt.tail = '\n%s' % ('\t' * (sub + 2)) + new_elt.tail = "\n%s" % ("\t" * (sub + 2)) if prev_elt is not None: - prev_elt.tail = '\n%s' % ('\t' * (sub + 2)) + prev_elt.tail = "\n%s" % ("\t" * (sub + 2)) prev_elt = new_elt top_elt.append(new_elt) - self.debug.write('changed=%s added key=%s value=%s tag=%s\n' % (changed, key, value, top_elt.tag)) + self.debug.write( + "changed=%s added key=%s value=%s tag=%s\n" + % (changed, key, value, top_elt.tag) + ) else: if isinstance(value, dict): - self.debug.write('calling copy_dict_to_element()\n') - if self.copy_dict_to_element(value, this_elt, sub=sub + 1, prev_elt=this_elt): + self.debug.write("calling copy_dict_to_element()\n") + if self.copy_dict_to_element( + value, this_elt, sub=sub + 1, prev_elt=this_elt + ): changed = True elif isinstance(value, list): all_sub_elts = top_elt.findall(key) @@ -313,19 +334,24 @@ def copy_dict_to_element(self, src, top_elt, sub=0, prev_elt=None): # set all elts for idx, item in enumerate(value): if isinstance(item, str): - if all_sub_elts[idx].text is None and item == '': + if all_sub_elts[idx].text is None and item == "": pass elif all_sub_elts[idx].text != item: all_sub_elts[idx].text = item changed = True - elif self.copy_dict_to_element(item, all_sub_elts[idx], sub=sub + 1, prev_elt=prev_elt): + elif self.copy_dict_to_element( + item, all_sub_elts[idx], sub=sub + 1, prev_elt=prev_elt + ): changed = True - elif this_elt.text is None and value == '': + elif this_elt.text is None and value == "": pass elif this_elt.text != value: this_elt.text = value changed = True - self.debug.write('changed=%s this_elt.text=%s value=%s\n' % (changed, this_elt.text, value)) + self.debug.write( + "changed=%s this_elt.text=%s value=%s\n" + % (changed, this_elt.text, value) + ) prev_elt = this_elt # Sub-elements must be completely described, so remove any missing elements @@ -333,25 +359,27 @@ def copy_dict_to_element(self, src, top_elt, sub=0, prev_elt=None): for child_elt in list(top_elt): if child_elt.tag not in src: changed = True - self.debug.write('changed=%s removed tag=%s\n' % (changed, child_elt.tag)) + self.debug.write( + "changed=%s removed tag=%s\n" % (changed, child_elt.tag) + ) top_elt.remove(child_elt) if prev_elt is not None: - prev_elt.tail = '\n%s' % ('\t' * (sub + 1)) + prev_elt.tail = "\n%s" % ("\t" * (sub + 1)) self.debug.flush() return changed @staticmethod def array_to_php(src, php_name): - """ Generate PHP commands to initialiaze a variable with contents of an array """ + """Generate PHP commands to initialiaze a variable with contents of an array""" array_values = "'" + "','".join(src) + "'" cmd = f"${php_name} = array({array_values});\n" return cmd @staticmethod def dict_to_php(src, php_name): - """ Generate PHP commands to initialiaze a variable with contents of a dict """ + """Generate PHP commands to initialiaze a variable with contents of a dict""" cmd = "${0} = array();\n".format(php_name) for key, value in src.items(): if value is not None: @@ -362,13 +390,13 @@ def dict_to_php(src, php_name): @staticmethod def element_to_dict(src_elt): - """ Create dict from XML src_elt """ + """Create dict from XML src_elt""" res = {} for elt in src_elt: if len(elt) > 0: value = PFSenseModule.element_to_dict(elt) else: - value = elt.text if elt.text is not None else '' + value = elt.text if elt.text is not None else "" if elt.tag in res: if not isinstance(res[elt.tag], list): @@ -379,7 +407,7 @@ def element_to_dict(src_elt): return res def config_get_path(self, name, default=None): - """ get value of a specific configuration path """ + """get value of a specific configuration path""" elt = self.find_elt_xpath(name) if elt is not None: return elt.text @@ -387,71 +415,85 @@ def config_get_path(self, name, default=None): return default def get_refid(self, node, name): - """ get refid of name in specific nodes """ + """get refid of name in specific nodes""" elt = self.find_elt(node, name) if elt is not None: - return xml_find(elt, 'refid').text + return xml_find(elt, "refid").text else: return None def get_caref(self, name): - """ get CA refid for name """ + """get CA refid for name""" # global is a special case - if name == 'global': - return 'global' + if name == "global": + return "global" # Otherwise search the ca elements - return self.get_refid('ca', name) + return self.get_refid("ca", name) def get_certref(self, name): - """ get Cert refid for name """ - return self.get_refid('cert', name) + """get Cert refid for name""" + return self.get_refid("cert", name) def get_crlref(self, name): - """ get CRL refid for name """ - return self.get_refid('crl', name) + """get CRL refid for name""" + return self.get_refid("crl", name) @staticmethod def get_username(): - """ get username logged """ + """get username logged""" username = pwd.getpwuid(os.getuid()).pw_name - if os.environ.get('SUDO_USER'): - username = os.environ.get('SUDO_USER') + if os.environ.get("SUDO_USER"): + username = os.environ.get("SUDO_USER") # sudo masks this - sshclient = os.environ.get('SSH_CLIENT') + sshclient = os.environ.get("SSH_CLIENT") if sshclient: - username = username + '@' + sshclient + username = username + "@" + sshclient return username def find_alias(self, name, aliastype=None): - """ return alias named name, having type aliastype if specified """ + """return alias named name, having type aliastype if specified""" for alias in self.aliases: - if xml_find(alias, 'name').text == name and (aliastype is None or xml_find(alias, 'type').text == aliastype): + if xml_find(alias, "name").text == name and ( + aliastype is None or xml_find(alias, "type").text == aliastype + ): return alias return None def is_ip_or_alias(self, address): - """ return True if address is an ip or an alias """ + """return True if address is an ip or an alias""" # Is it an alias? - if (self.find_alias(address, 'host') is not None - or self.find_alias(address, 'network') is not None - or self.find_alias(address, 'urltable') is not None): + if ( + self.find_alias(address, "host") is not None + or self.find_alias(address, "network") is not None + or self.find_alias(address, "urltable") is not None + ): return True # Is it an IP address or network? - if self.is_ipv4_address(address) or self.is_ipv4_network(address) or self.is_ipv6_address(address) or self.is_ipv6_network(address): + if ( + self.is_ipv4_address(address) + or self.is_ipv4_network(address) + or self.is_ipv6_address(address) + or self.is_ipv6_network(address) + ): return True # None of the above return False def is_gateway_group(self, gwgroup): - """ return True if gwgroup is a gateway group """ - return self.find_elt_xpath(f"./gateways/gateway_group[name='{gwgroup}']") is not None + """return True if gwgroup is a gateway group""" + return ( + self.find_elt_xpath(f"./gateways/gateway_group[name='{gwgroup}']") + is not None + ) def is_port_or_alias(self, port): - """ return True if port is a valid port number or an alias """ - if (self.find_alias(port, 'port') is not None - or self.find_alias(port, 'urltable_ports') is not None): + """return True if port is a valid port number or an alias""" + if ( + self.find_alias(port, "port") is not None + or self.find_alias(port, "urltable_ports") is not None + ): return True try: if int(port) > 0 and int(port) < 65536: @@ -461,57 +503,57 @@ def is_port_or_alias(self, port): return False def is_virtual_ip(self, addr): - """ return True if addr is a virtual ip """ + """return True if addr is a virtual ip""" if self.virtualip is None: return False - if self.find_elt('vip', addr, 'subnet', root_elt=self.virtualip) is None: + if self.find_elt("vip", addr, "subnet", root_elt=self.virtualip) is None: return False return True def get_virtual_ip_interface(self, vip): - """ return interface name for virtual IP name or network """ + """return interface name for virtual IP name or network""" if self.virtualip is None: return None - vip_elt = self.find_elt('vip', vip, 'descr', root_elt=self.virtualip) + vip_elt = self.find_elt("vip", vip, "descr", root_elt=self.virtualip) if vip_elt is None: - vip_elt = self.find_elt('vip', vip, 'subnet', root_elt=self.virtualip) + vip_elt = self.find_elt("vip", vip, "subnet", root_elt=self.virtualip) if vip_elt is None: return None - uniqid_elt = vip_elt.find('uniqid') + uniqid_elt = vip_elt.find("uniqid") if uniqid_elt is None: return None - return "_vip" + xml_find(vip_elt, 'uniqid').text + return "_vip" + xml_find(vip_elt, "uniqid").text def find_queue(self, name, interface=None, enabled=False): - """ return QOS queue if found """ + """return QOS queue if found""" # iterate each interface for shaper_elt in self.shapers: if interface is not None: - interface_elt = shaper_elt.find('interface') + interface_elt = shaper_elt.find("interface") if interface_elt is None or interface_elt.text != interface: continue if enabled: - enabled_elt = shaper_elt.find('enabled') - if enabled_elt is None or enabled_elt.text != 'on': + enabled_elt = shaper_elt.find("enabled") + if enabled_elt is None or enabled_elt.text != "on": continue # iterate each queue - for queue_elt in shaper_elt.findall('.//queue'): - name_elt = queue_elt.find('name') + for queue_elt in shaper_elt.findall(".//queue"): + name_elt = queue_elt.find("name") if name_elt is None or name_elt.text != name: continue if enabled: - enabled_elt = queue_elt.find('enabled') - if enabled_elt is None or enabled_elt.text != 'on': + enabled_elt = queue_elt.find("enabled") + if enabled_elt is None or enabled_elt.text != "on": continue # found it @@ -520,16 +562,16 @@ def find_queue(self, name, interface=None, enabled=False): return None def find_limiter(self, name, enabled=False): - """ return QOS limiter if found """ + """return QOS limiter if found""" # iterate each queue for queue_elt in self.dnshapers: if enabled: - enabled_elt = queue_elt.find('enabled') - if enabled_elt is None or enabled_elt.text != 'on': + enabled_elt = queue_elt.find("enabled") + if enabled_elt is None or enabled_elt.text != "on": continue - name_elt = queue_elt.find('name') + name_elt = queue_elt.find("name") if name_elt is None or name_elt.text != name: continue @@ -538,86 +580,114 @@ def find_limiter(self, name, enabled=False): return None def find_vlan(self, interface, tag): - """ return vlan elt if found """ + """return vlan elt if found""" if self.vlans is None: - self.vlans = self.get_element('vlans') + self.vlans = self.get_element("vlans") if self.vlans is not None: for vlan in self.vlans: - if xml_find(vlan, 'if').text == interface and xml_find(vlan, 'tag').text == tag: + if ( + xml_find(vlan, "if").text == interface + and xml_find(vlan, "tag").text == tag + ): return vlan return None def _create_gw_elt(self, name, interface_id, protocol): - gw_elt = ET.Element('gateway_item') - gw_elt.append(self.new_element('interface', interface_id)) - gw_elt.append(self.new_element('gateway', 'dynamic')) - gw_elt.append(self.new_element('name', name)) - gw_elt.append(self.new_element('weight', '1')) - gw_elt.append(self.new_element('ipprotocol', protocol)) - gw_elt.append(self.new_element('descr', 'Interface ' + name + ' Gateway')) + gw_elt = ET.Element("gateway_item") + gw_elt.append(self.new_element("interface", interface_id)) + gw_elt.append(self.new_element("gateway", "dynamic")) + gw_elt.append(self.new_element("name", name)) + gw_elt.append(self.new_element("weight", "1")) + gw_elt.append(self.new_element("ipprotocol", protocol)) + gw_elt.append(self.new_element("descr", "Interface " + name + " Gateway")) return gw_elt - def find_gateway_elt(self, name, interface=None, protocol=None, dhcp=False, vti=False): - """ return gateway elt if found """ + def find_gateway_elt( + self, name, interface=None, protocol=None, dhcp=False, vti=False + ): + """return gateway elt if found""" for gw_elt in self.gateways: - if gw_elt.tag != 'gateway_item': + if gw_elt.tag != "gateway_item": continue - if protocol is not None and xml_find(gw_elt, 'ipprotocol').text != protocol: + if protocol is not None and xml_find(gw_elt, "ipprotocol").text != protocol: continue - if interface is not None and xml_find(gw_elt, 'interface').text != interface: + if ( + interface is not None + and xml_find(gw_elt, "interface").text != interface + ): continue - if xml_find(gw_elt, 'name').text == name: + if xml_find(gw_elt, "name").text == name: return gw_elt for interface_elt in self.interfaces: - descr_elt = interface_elt.find('descr') + descr_elt = interface_elt.find("descr") if descr_elt is None or descr_elt.text is None: continue - if_elt = interface_elt.find('if') + if_elt = interface_elt.find("if") if if_elt is None or if_elt.text is None: continue descr_text = descr_elt.text.strip().upper() # todo: implement interface match with ipsec tunnels threw vtimaps - if vti and (protocol is None or protocol == 'inet') and if_elt.text.startswith('ipsec') and descr_text + '_VTIV4' == name: - return self._create_gw_elt(name, interface_elt.tag, 'inet') - - if vti and (protocol is None or protocol == 'inet6') and if_elt.text.startswith('ipsec') and descr_text + '_VTIV6' == name: - return self._create_gw_elt(name, interface_elt.tag, 'inet6') + if ( + vti + and (protocol is None or protocol == "inet") + and if_elt.text.startswith("ipsec") + and descr_text + "_VTIV4" == name + ): + return self._create_gw_elt(name, interface_elt.tag, "inet") + + if ( + vti + and (protocol is None or protocol == "inet6") + and if_elt.text.startswith("ipsec") + and descr_text + "_VTIV6" == name + ): + return self._create_gw_elt(name, interface_elt.tag, "inet6") if dhcp: - ipaddr_elt = interface_elt.find('ipaddr') - if (protocol is None or protocol == 'inet') and ipaddr_elt is not None and ipaddr_elt.text == 'dhcp' and descr_text + "_DHCP" == name: - return self._create_gw_elt(name, interface_elt.tag, 'inet') - - ipaddr_elt = interface_elt.find('ipaddrv6') - if (protocol is None or protocol == 'inet6') and ipaddr_elt is not None and ipaddr_elt.text == 'dhcp6' and descr_text + "_DHCP6" == name: - return self._create_gw_elt(name, interface_elt.tag, 'inet6') + ipaddr_elt = interface_elt.find("ipaddr") + if ( + (protocol is None or protocol == "inet") + and ipaddr_elt is not None + and ipaddr_elt.text == "dhcp" + and descr_text + "_DHCP" == name + ): + return self._create_gw_elt(name, interface_elt.tag, "inet") + + ipaddr_elt = interface_elt.find("ipaddrv6") + if ( + (protocol is None or protocol == "inet6") + and ipaddr_elt is not None + and ipaddr_elt.text == "dhcp6" + and descr_text + "_DHCP6" == name + ): + return self._create_gw_elt(name, interface_elt.tag, "inet6") return None - def find_gateway_group_elt(self, name, protocol='inet'): - """ return gateway_group elt if found """ + def find_gateway_group_elt(self, name, protocol="inet"): + """return gateway_group elt if found""" for gw_grp_elt in self.gateways: - if gw_grp_elt.tag != 'gateway_group': + if gw_grp_elt.tag != "gateway_group": continue - if xml_find(gw_grp_elt, 'name').text != name: + if xml_find(gw_grp_elt, "name").text != name: continue # check if protocol match match_protocol = True for gw_elt in gw_grp_elt: - if gw_elt.tag != 'item' or gw_elt.text is None: + if gw_elt.tag != "item" or gw_elt.text is None: continue - items = gw_elt.text.split('|') + items = gw_elt.text.split("|") if not items or self.find_gateway_elt(items[0], None, protocol) is None: match_protocol = False break @@ -630,7 +700,7 @@ def find_gateway_group_elt(self, name, protocol='inet'): return None def find_active_gateways(self): - """ returns list of active gateways """ + """returns list of active gateways""" (retcode, raw_output, error) = self.phpshell("playback gatewaystatus") write = False @@ -651,53 +721,63 @@ def find_active_gateways(self): for item in line.split(): dline[head[c]] = item c += 1 - if dline is not {}: + if dline != {}: data.append(dline) return data - def find_ca_elt(self, ca, search_field='descr'): - """ return certificate authority elt if found """ - return self.find_elt('ca', ca, search_field) + def find_ca_elt(self, ca, search_field="descr"): + """return certificate authority elt if found""" + return self.find_elt("ca", ca, search_field) - def find_cert_elt(self, cert, search_field='descr'): - """ return certificate elt if found """ - return self.find_elt('cert', cert, search_field) + def find_cert_elt(self, cert, search_field="descr"): + """return certificate elt if found""" + return self.find_elt("cert", cert, search_field) - def find_crl_elt(self, crl, search_field='descr'): - """ return certificate revocation list elt if found """ - return self.find_elt('crl', crl, search_field) + def find_crl_elt(self, crl, search_field="descr"): + """return certificate revocation list elt if found""" + return self.find_elt("crl", crl, search_field) def find_schedule_elt(self, name): - """ return schedule elt if found """ + """return schedule elt if found""" return self.find_elt_xpath("./schedules/schedule[name='{0}']".format(name)) @staticmethod - def uniqid(prefix='', more_entropy=False): - """ return an identifier based on time """ + def uniqid(prefix="", more_entropy=False): + """return an identifier based on time""" if more_entropy: - return prefix + '{0:x}{1:05x}{2:.8F}'.format(int(time.time()), int(time.time() * 1000000) % 0x100000, random.random() * 10) + return prefix + "{0:x}{1:05x}{2:.8F}".format( + int(time.time()), + int(time.time() * 1000000) % 0x100000, + random.random() * 10, + ) - return prefix + '{0:x}{1:05x}'.format(int(time.time()), int(time.time() * 1000000) % 0x100000) + return prefix + "{0:x}{1:05x}".format( + int(time.time()), int(time.time() * 1000000) % 0x100000 + ) def phpshell(self, command, debug=True): - """ Run a command in the php developer shell """ + """Run a command in the php developer shell""" phpshell = "global $config;\n" if debug: phpshell = "global $debug;\n$debug = 1;\n" phpshell += command + "\nexec\nexit" # Dummy argument suppresses displaying help message - return self.module.run_command('/usr/local/sbin/pfSsh.php dummy', data=phpshell) + return self.module.run_command("/usr/local/sbin/pfSsh.php dummy", data=phpshell) def php(self, command): - """ Run a command in php and return the output """ - cmd = '\n' - (dummy, stdout, stderr) = self.module.run_command('/usr/local/bin/php', data=cmd) + cmd += "\n?>\n" + (dummy, stdout, stderr) = self.module.run_command( + "/usr/local/bin/php", data=cmd + ) # If /var/run/booting is in place, various requires will emit a "." - (stdout, nsubs) = re.subn(r'^\.+', '', stdout) + (stdout, nsubs) = re.subn(r"^\.+", "", stdout) if nsubs > 0: - self.module.warn('/var/run/booting appears to be present, confirm successful boot and remove if appropriate.') + self.module.warn( + "/var/run/booting appears to be present, confirm successful boot and remove if appropriate." + ) # TODO: check stderr for errors try: result = json.loads(stdout) @@ -705,27 +785,29 @@ def php(self, command): self.module.fail_json(msg=f"{e}", cmd=cmd, stdout=stdout, stderr=stderr) return result - def write_config(self, descr='Updated by ansible pfsense module'): - """ Generate config file """ - revision = self.get_element('revision') - xml_find(revision, 'time').text = '%d' % time.time() - revdescr = revision.find('description') + def write_config(self, descr="Updated by ansible pfsense module"): + """Generate config file""" + revision = self.get_element("revision") + xml_find(revision, "time").text = "%d" % time.time() + revdescr = revision.find("description") if revdescr is None: - revdescr = ET.Element('description') + revdescr = ET.Element("description") revision.append(revdescr) revdescr.text = descr username = self.get_username() - xml_find(revision, 'username').text = username + xml_find(revision, "username").text = username (tmp_handle, tmp_name) = mkstemp() os.close(tmp_handle) if sys.version_info >= (3, 4): - self.tree.write(tmp_name, xml_declaration=True, method='xml', short_empty_elements=False) + self.tree.write( + tmp_name, xml_declaration=True, method="xml", short_empty_elements=False + ) else: - self.tree.write(tmp_name, xml_declaration=True, method='xml') + self.tree.write(tmp_name, xml_declaration=True, method="xml") shutil.move(tmp_name, self.config) os.chmod(self.config, 0o644) try: - os.remove('/tmp/config.cache') + os.remove("/tmp/config.cache") except OSError as exception: if exception.errno == 2: # suppress "No such file or directory error @@ -735,7 +817,7 @@ def write_config(self, descr='Updated by ansible pfsense module'): @staticmethod def get_version(): - """ get pfSense version """ + """get pfSense version""" # TODO: use subprocess when we'll drop support for python 2.7 os.system("pkg-static info | grep pfSense-base > /tmp/pfVersion") vfile = open("/tmp/pfVersion", "r") @@ -745,17 +827,21 @@ def get_version(): @staticmethod def is_ce_version(version): - """ return True if version is a CE version (for now, we only have 2.x patterns) """ + """return True if version is a CE version (for now, we only have 2.x patterns)""" return version[0] == 2 def is_version(self, version, or_more=True): - """ check target pfSense version """ + """check target pfSense version""" if self.pfsense_version is None: pfsense_version = self.get_version() self.pfsense_version = [] - match = re.match(r'(\d+)\.(\d+)\.?(\d+)?', pfsense_version) + match = re.match(r"(\d+)\.(\d+)\.?(\d+)?", pfsense_version) if match is None: - self.module.fail_json(msg="Unable to get version from pfSense (got '{0}')".format(pfsense_version)) + self.module.fail_json( + msg="Unable to get version from pfSense (got '{0}')".format( + pfsense_version + ) + ) for idx in range(0, match.lastindex): self.pfsense_version.append(int(match.group(idx + 1))) @@ -771,21 +857,25 @@ def is_version(self, version, or_more=True): if self.pfsense_version[idx] > ver and or_more: return True - if ver < self.pfsense_version[idx] and not or_more or ver > self.pfsense_version[idx]: + if ( + ver < self.pfsense_version[idx] + and not or_more + or ver > self.pfsense_version[idx] + ): return False return True def is_at_least_2_5_2(self): - """ check target pfSense version """ + """check target pfSense version""" return self.is_version([2, 5, 2]) or self.is_version([21, 5]) def is_at_least_2_5_0(self): - """ check target pfSense version """ + """check target pfSense version""" return self.is_version([2, 5, 0]) or self.is_version([21, 2]) def apply_ipsec_changes(self): - """ execute pfSense code to appy ipsec changes """ + """execute pfSense code to appy ipsec changes""" if self.is_at_least_2_5_0(): return self.phpshell( "require_once('vpn.inc');" diff --git a/plugins/module_utils/route.py b/plugins/module_utils/route.py index a73b082b..24099f8a 100644 --- a/plugins/module_utils/route.py +++ b/plugins/module_utils/route.py @@ -4,15 +4,18 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) ROUTE_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - descr=dict(required=True, type='str'), - gateway=dict(required=False, type='str'), - network=dict(required=False, type='str'), - disabled=dict(default=False, type='bool'), + state=dict(default="present", choices=["present", "absent"]), + descr=dict(required=True, type="str"), + gateway=dict(required=False, type="str"), + network=dict(required=False, type="str"), + disabled=dict(default=False, type="bool"), ) ROUTE_REQUIRED_IF = [ @@ -21,11 +24,11 @@ class PFSenseRouteModule(PFSenseModuleBase): - """ module managing pfsense routes """ + """module managing pfsense routes""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return ROUTE_ARGUMENT_SPEC ############################## @@ -34,19 +37,19 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseRouteModule, self).__init__(module, pfsense) self.name = "pfsense_route" - self.root_elt = self.pfsense.get_element('staticroutes') + self.root_elt = self.pfsense.get_element("staticroutes") self.obj = dict() self.route_cmd = list() if self.root_elt is None: - self.root_elt = self.pfsense.new_element('staticroutes') + self.root_elt = self.pfsense.new_element("staticroutes") self.pfsense.root.append(self.root_elt) ############################## # params processing # def _expand_alias(self, networks): - """ return real addresses of alias """ + """return real addresses of alias""" ret = list() while len(networks) > 0: @@ -54,94 +57,113 @@ def _expand_alias(self, networks): if self.pfsense.is_ipv4_network(alias, strict=False): ret.append(alias) else: - alias_elt = self.pfsense.find_alias(alias, aliastype='host') + alias_elt = self.pfsense.find_alias(alias, aliastype="host") if alias_elt is None: - alias_elt = self.pfsense.find_alias(alias, aliastype='network') - networks += alias_elt.find('address').text.split(' ') + alias_elt = self.pfsense.find_alias(alias, aliastype="network") + networks += alias_elt.find("address").text.split(" ") return ret def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() self.obj = obj - obj['descr'] = params['descr'] + obj["descr"] = params["descr"] target_elt = self._find_target() - if params['state'] == 'present': - self._get_ansible_param(obj, 'gateway') - self._get_ansible_param(obj, 'descr') - self._get_ansible_param(obj, 'network') + if params["state"] == "present": + self._get_ansible_param(obj, "gateway") + self._get_ansible_param(obj, "descr") + self._get_ansible_param(obj, "network") - self._get_ansible_param_bool(obj, 'disabled') + self._get_ansible_param_bool(obj, "disabled") if target_elt is not None: - old_network = target_elt.find('network').text - if params['state'] == 'absent' or old_network != params['network']: + old_network = target_elt.find("network").text + if params["state"] == "absent" or old_network != params["network"]: networks = self._expand_alias([old_network]) for network in networks: if self.pfsense.is_ipv4_address(network): - network = network + '/32' + network = network + "/32" elif self.pfsense.is_ipv6_address(old_network): - network = network + '/128' + network = network + "/128" if self.pfsense.is_ipv4_network(network, False): - family = '-inet' + family = "-inet" else: - family = '-inet6' + family = "-inet6" - self.route_cmd.append('/sbin/route delete {0} {1}'.format(family, network)) + self.route_cmd.append( + "/sbin/route delete {0} {1}".format(family, network) + ) return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params['state'] == 'present': - gw_elt = self.pfsense.find_gateway_elt(params['gateway'], dhcp=True, vti=True) + if params["state"] == "present": + gw_elt = self.pfsense.find_gateway_elt( + params["gateway"], dhcp=True, vti=True + ) if gw_elt is None: - self.module.fail_json(msg='The gateway {0} does not exist'.format(params['gateway'])) - - if (self.pfsense.is_ipv4_address(params['network']) and gw_elt.find('ipprotocol').text == 'inet6' or - self.pfsense.is_ipv6_address(params['network']) and gw_elt.find('ipprotocol').text == 'inet'): - msg = 'The gateway "{0}" is a different Address Family than network "{1}".'.format(gw_elt.find('gateway').text, params['network']) + self.module.fail_json( + msg="The gateway {0} does not exist".format(params["gateway"]) + ) + + if ( + self.pfsense.is_ipv4_address(params["network"]) + and gw_elt.find("ipprotocol").text == "inet6" + or self.pfsense.is_ipv6_address(params["network"]) + and gw_elt.find("ipprotocol").text == "inet" + ): + msg = 'The gateway "{0}" is a different Address Family than network "{1}".'.format( + gw_elt.find("gateway").text, params["network"] + ) self.module.fail_json(msg=msg) - if (not self.pfsense.is_ip_network(params['network'], False) and self.pfsense.find_alias(params['network'], aliastype='host') is None and - self.pfsense.find_alias(params['network'], aliastype='network') is None): - self.module.fail_json(msg='A valid IPv4 or IPv6 destination network or alias must be specified.') + if ( + not self.pfsense.is_ip_network(params["network"], False) + and self.pfsense.find_alias(params["network"], aliastype="host") is None + and self.pfsense.find_alias(params["network"], aliastype="network") + is None + ): + self.module.fail_json( + msg="A valid IPv4 or IPv6 destination network or alias must be specified." + ) ############################## # XML processing # def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('route') + """create the XML target_elt""" + return self.pfsense.new_element("route") def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" for route_elt in self.root_elt: - if route_elt.find('descr').text == self.obj['descr']: + if route_elt.find("descr").text == self.obj["descr"]: return route_elt return None @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - return ['disabled'] + """returns the list of params to remove if they are not set""" + return ["disabled"] ############################## # run # def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" for cmd in self.route_cmd: self.module.run_command(cmd) - return self.pfsense.phpshell(''' + return self.pfsense.phpshell( + """ require_once("filter.inc"); $retval = 0; if (file_exists("{$g['tmp_path']}/.system_routes.apply")) { @@ -157,25 +179,39 @@ def _update(self): setup_gateways_monitor(); if ($retval == 0) clear_subsystem_dirty('staticroutes'); -''') +""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.obj['descr']) + """return obj's name""" + return "'{0}'".format(self.obj["descr"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'network') - values += self.format_cli_field(self.obj, 'gateway') - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool, default=False) + values += self.format_cli_field(self.obj, "network") + values += self.format_cli_field(self.obj, "gateway") + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool, default=False + ) else: - values += self.format_updated_cli_field(self.obj, before, 'network', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'gateway', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'disabled', fvalue=self.fvalue_bool, default=False, add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "network", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "gateway", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "disabled", + fvalue=self.fvalue_bool, + default=False, + add_comma=(values), + ) return values diff --git a/plugins/module_utils/rule.py b/plugins/module_utils/rule.py index 44f19b55..4814b212 100644 --- a/plugins/module_utils/rule.py +++ b/plugins/module_utils/rule.py @@ -4,42 +4,66 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import time import re -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) RULE_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - action=dict(default='pass', choices=['pass', 'block', 'match', 'reject']), - state=dict(default='present', choices=['present', 'absent']), - disabled=dict(default=False, required=False, type='bool'), - interface=dict(required=True, type='str'), - floating=dict(required=False, type='bool'), + name=dict(required=True, type="str"), + action=dict(default="pass", choices=["pass", "block", "match", "reject"]), + state=dict(default="present", choices=["present", "absent"]), + disabled=dict(default=False, required=False, type="bool"), + interface=dict(required=True, type="str"), + floating=dict(required=False, type="bool"), direction=dict(required=False, choices=["any", "in", "out"]), - ipprotocol=dict(default='inet', choices=['inet', 'inet46', 'inet6']), - protocol=dict(default='any', choices=["any", "tcp", "udp", "tcp/udp", "icmp", "igmp", "ospf", "esp", "ah", "gre", "pim", "sctp", "pfsync", "carp"]), - source=dict(required=False, type='str'), - source_port=dict(required=False, type='str'), - destination=dict(required=False, type='str'), - destination_port=dict(required=False, type='str'), - log=dict(required=False, type='bool'), - after=dict(required=False, type='str'), - before=dict(required=False, type='str'), - tcpflags_any=dict(required=False, type='bool'), - statetype=dict(default='keep state', choices=['keep state', 'sloppy state', 'synproxy state', 'none']), - queue=dict(required=False, type='str'), - ackqueue=dict(required=False, type='str'), - in_queue=dict(required=False, type='str'), - out_queue=dict(required=False, type='str'), - queue_error=dict(default=True, type='bool'), - gateway=dict(default='default', type='str'), - tracker=dict(required=False, type='str'), - icmptype=dict(default='any', required=False, type='str'), - sched=dict(required=False, type='str'), - quick=dict(default=False, type='bool'), + ipprotocol=dict(default="inet", choices=["inet", "inet46", "inet6"]), + protocol=dict( + default="any", + choices=[ + "any", + "tcp", + "udp", + "tcp/udp", + "icmp", + "igmp", + "ospf", + "esp", + "ah", + "gre", + "pim", + "sctp", + "pfsync", + "carp", + ], + ), + source=dict(required=False, type="str"), + source_port=dict(required=False, type="str"), + destination=dict(required=False, type="str"), + destination_port=dict(required=False, type="str"), + log=dict(required=False, type="bool"), + after=dict(required=False, type="str"), + before=dict(required=False, type="str"), + tcpflags_any=dict(required=False, type="bool"), + statetype=dict( + default="keep state", + choices=["keep state", "sloppy state", "synproxy state", "none"], + ), + queue=dict(required=False, type="str"), + ackqueue=dict(required=False, type="str"), + in_queue=dict(required=False, type="str"), + out_queue=dict(required=False, type="str"), + queue_error=dict(default=True, type="bool"), + gateway=dict(default="default", type="str"), + tracker=dict(required=False, type="str"), + icmptype=dict(default="any", required=False, type="str"), + sched=dict(required=False, type="str"), + quick=dict(default=False, type="bool"), ) RULE_REQUIRED_IF = [ @@ -50,17 +74,27 @@ # These are rule elements that are (currently) unmanaged by this module RULE_UNMANAGED_ELEMENTS = [ - 'created', 'id', 'max', 'max-src-conn', 'max-src-nodes', 'max-src-states', 'os', - 'statetimeout', 'statetype', 'tag', 'tagged', 'updated' + "created", + "id", + "max", + "max-src-conn", + "max-src-nodes", + "max-src-states", + "os", + "statetimeout", + "statetype", + "tag", + "tagged", + "updated", ] class PFSenseRuleModule(PFSenseModuleBase): - """ module managing pfsense rules """ + """module managing pfsense rules""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return RULE_ARGUMENT_SPEC ############################## @@ -71,18 +105,18 @@ def __init__(self, module, pfsense=None): self.name = "pfsense_rule" # Override for use with aggregate self.argument_spec = RULE_ARGUMENT_SPEC - self.root_elt = self.pfsense.get_element('filter') + self.root_elt = self.pfsense.get_element("filter") self.obj = dict() - self.result['added'] = [] - self.result['deleted'] = [] - self.result['modified'] = [] + self.result["added"] = [] + self.result["deleted"] = [] + self.result["modified"] = [] self.obj = None - self._floating = None # are we on floating rule - self._floating_interfaces = None # rule's interfaces before change - self._after = None # insert/move after - self._before = None # insert/move before + self._floating = None # are we on floating rule + self._floating_interfaces = None # rule's interfaces before change + self._after = None # insert/move after + self._before = None # insert/move before self._position_changed = False self.trackers = set() @@ -91,248 +125,378 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() self.obj = obj - obj['descr'] = params['name'] + obj["descr"] = params["name"] - if params.get('floating'): - obj['floating'] = 'yes' - obj['interface'] = self._parse_floating_interfaces(params['interface']) + if params.get("floating"): + obj["floating"] = "yes" + obj["interface"] = self._parse_floating_interfaces(params["interface"]) else: - obj['interface'] = self.pfsense.parse_interface(params['interface']) - - if params['state'] == 'present': - obj['type'] = params['action'] - obj['ipprotocol'] = params['ipprotocol'] - obj['statetype'] = params['statetype'] - - obj['source'] = self.pfsense.parse_address(params['source']) - if params.get('source_port'): - self.pfsense.parse_port(params['source_port'], obj['source']) - - obj['destination'] = self.pfsense.parse_address(params['destination']) - if params.get('destination_port'): - self.pfsense.parse_port(params['destination_port'], obj['destination']) - - if params['protocol'] not in ['tcp', 'udp', 'tcp/udp'] and ('port' in obj['source'] or 'port' in obj['destination']): - self.module.fail_json(msg="{0}: you can't use ports on protocols other than tcp, udp or tcp/udp".format(self._get_obj_name())) - - for param in ['destination', 'source']: - if 'address' in obj[param]: - self.pfsense.check_ip_address(obj[param]['address'], obj['ipprotocol'], 'rule') - if 'network' in obj[param]: - self.pfsense.check_ip_address(obj[param]['network'], obj['ipprotocol'], 'rule', allow_networks=True) - - self._get_ansible_param(obj, 'protocol', exclude='any') - if params['protocol'] == 'icmp': - self._get_ansible_param(obj, 'icmptype') - self._get_ansible_param(obj, 'direction') - self._get_ansible_param(obj, 'queue', fname='defaultqueue') - if params.get('ackqueue'): - self._get_ansible_param(obj, 'ackqueue') - self._get_ansible_param(obj, 'in_queue', fname='dnpipe') - self._get_ansible_param(obj, 'out_queue', fname='pdnpipe') - self._get_ansible_param(obj, 'associated-rule-id') - self._get_ansible_param(obj, 'tracker', exclude='') - self._get_ansible_param(obj, 'gateway', exclude='default') - self._get_ansible_param(obj, 'sched') - - self._get_ansible_param_bool(obj, 'disabled', value='') - self._get_ansible_param_bool(obj, 'log', value='') - self._get_ansible_param_bool(obj, 'quick') - self._get_ansible_param_bool(obj, 'tcpflags_any', value='') - - self._floating = 'floating' in self.obj and self.obj['floating'] == 'yes' - self._after = params.get('after') - self._before = params.get('before') + obj["interface"] = self.pfsense.parse_interface(params["interface"]) + + if params["state"] == "present": + obj["type"] = params["action"] + obj["ipprotocol"] = params["ipprotocol"] + obj["statetype"] = params["statetype"] + + obj["source"] = self.pfsense.parse_address(params["source"]) + if params.get("source_port"): + self.pfsense.parse_port(params["source_port"], obj["source"]) + + obj["destination"] = self.pfsense.parse_address(params["destination"]) + if params.get("destination_port"): + self.pfsense.parse_port(params["destination_port"], obj["destination"]) + + if params["protocol"] not in ["tcp", "udp", "tcp/udp"] and ( + "port" in obj["source"] or "port" in obj["destination"] + ): + self.module.fail_json( + msg="{0}: you can't use ports on protocols other than tcp, udp or tcp/udp".format( + self._get_obj_name() + ) + ) + + for param in ["destination", "source"]: + if "address" in obj[param]: + self.pfsense.check_ip_address( + obj[param]["address"], obj["ipprotocol"], "rule" + ) + if "network" in obj[param]: + self.pfsense.check_ip_address( + obj[param]["network"], + obj["ipprotocol"], + "rule", + allow_networks=True, + ) + + self._get_ansible_param(obj, "protocol", exclude="any") + if params["protocol"] == "icmp": + self._get_ansible_param(obj, "icmptype") + self._get_ansible_param(obj, "direction") + self._get_ansible_param(obj, "queue", fname="defaultqueue") + if params.get("ackqueue"): + self._get_ansible_param(obj, "ackqueue") + self._get_ansible_param(obj, "in_queue", fname="dnpipe") + self._get_ansible_param(obj, "out_queue", fname="pdnpipe") + self._get_ansible_param(obj, "associated-rule-id") + self._get_ansible_param(obj, "tracker", exclude="") + self._get_ansible_param(obj, "gateway", exclude="default") + self._get_ansible_param(obj, "sched") + + self._get_ansible_param_bool(obj, "disabled", value="") + self._get_ansible_param_bool(obj, "log", value="") + self._get_ansible_param_bool(obj, "quick") + self._get_ansible_param_bool(obj, "tcpflags_any", value="") + + self._floating = "floating" in self.obj and self.obj["floating"] == "yes" + self._after = params.get("after") + self._before = params.get("before") return obj def _parse_floating_interfaces(self, interfaces): - """ validate param interface field when floating is true """ + """validate param interface field when floating is true""" res = [] - for interface in interfaces.split(','): - if interface == 'any': + for interface in interfaces.split(","): + if interface == "any": res.append(interface) else: res.append(self.pfsense.parse_interface(interface)) self._floating_interfaces = interfaces - return ','.join(res) + return ",".join(res) def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params.get('ackqueue') and params['queue'] is None: - self.module.fail_json(msg='A default queue must be selected when an acknowledge queue is also selected') + if params.get("ackqueue") and params["queue"] is None: + self.module.fail_json( + msg="A default queue must be selected when an acknowledge queue is also selected" + ) - if params.get('ackqueue') and params['ackqueue'] == params['queue']: - self.module.fail_json(msg='Acknowledge queue and default queue cannot be the same') + if params.get("ackqueue") and params["ackqueue"] == params["queue"]: + self.module.fail_json( + msg="Acknowledge queue and default queue cannot be the same" + ) # as in pfSense 2.4, the GUI accepts any queue defined on any interface without checking, we do the same - if params.get('ackqueue') and self.pfsense.find_queue(params['ackqueue'], enabled=True) is None and params['queue_error']: - self.module.fail_json(msg='Failed to find enabled ackqueue=%s' % params['ackqueue']) - - if params.get('queue') is not None and self.pfsense.find_queue(params['queue'], enabled=True) is None and params['queue_error']: - self.module.fail_json(msg='Failed to find enabled queue=%s' % params['queue']) - - if params.get('out_queue') is not None and params['in_queue'] is None: - self.module.fail_json(msg='A queue must be selected for the In direction before selecting one for Out too') - - if params.get('out_queue') is not None and params['out_queue'] == params['in_queue']: - self.module.fail_json(msg='In and Out Queue cannot be the same') - - if params.get('out_queue') is not None and self.pfsense.find_limiter(params['out_queue'], enabled=True) is None: - self.module.fail_json(msg='Failed to find enabled out_queue=%s' % params['out_queue']) - - if params.get('in_queue') is not None and self.pfsense.find_limiter(params['in_queue'], enabled=True) is None: - self.module.fail_json(msg='Failed to find enabled in_queue=%s' % params['in_queue']) - - if params.get('floating') and params.get('direction') == 'any' and (params['in_queue'] is not None or params['out_queue'] is not None): - self.module.fail_json(msg='Limiters can not be used in Floating rules without choosing a direction') - - if params.get('after') and params.get('before'): - self.module.fail_json(msg='Cannot specify both after and before') - elif params.get('after'): - if params['after'] == params['name']: - self.module.fail_json(msg='Cannot specify the current rule in after') - elif params.get('before'): - if params['before'] == params['name']: - self.module.fail_json(msg='Cannot specify the current rule in before') + if ( + params.get("ackqueue") + and self.pfsense.find_queue(params["ackqueue"], enabled=True) is None + and params["queue_error"] + ): + self.module.fail_json( + msg="Failed to find enabled ackqueue=%s" % params["ackqueue"] + ) + + if ( + params.get("queue") is not None + and self.pfsense.find_queue(params["queue"], enabled=True) is None + and params["queue_error"] + ): + self.module.fail_json( + msg="Failed to find enabled queue=%s" % params["queue"] + ) + + if params.get("out_queue") is not None and params["in_queue"] is None: + self.module.fail_json( + msg="A queue must be selected for the In direction before selecting one for Out too" + ) + + if ( + params.get("out_queue") is not None + and params["out_queue"] == params["in_queue"] + ): + self.module.fail_json(msg="In and Out Queue cannot be the same") + + if ( + params.get("out_queue") is not None + and self.pfsense.find_limiter(params["out_queue"], enabled=True) is None + ): + self.module.fail_json( + msg="Failed to find enabled out_queue=%s" % params["out_queue"] + ) + + if ( + params.get("in_queue") is not None + and self.pfsense.find_limiter(params["in_queue"], enabled=True) is None + ): + self.module.fail_json( + msg="Failed to find enabled in_queue=%s" % params["in_queue"] + ) + + if ( + params.get("floating") + and params.get("direction") == "any" + and (params["in_queue"] is not None or params["out_queue"] is not None) + ): + self.module.fail_json( + msg="Limiters can not be used in Floating rules without choosing a direction" + ) + + if params.get("after") and params.get("before"): + self.module.fail_json(msg="Cannot specify both after and before") + elif params.get("after"): + if params["after"] == params["name"]: + self.module.fail_json(msg="Cannot specify the current rule in after") + elif params.get("before"): + if params["before"] == params["name"]: + self.module.fail_json(msg="Cannot specify the current rule in before") # gateway - if params.get('gateway') is not None and params['gateway'] != 'default': - if params['ipprotocol'] == 'inet46': - self.module.fail_json(msg='Gateway selection is not valid for "IPV4+IPV6" address family.') - elif (self.pfsense.find_gateway_group_elt(params['gateway'], params['ipprotocol']) is None - and self.pfsense.find_gateway_elt(params['gateway'], None, params['ipprotocol']) is None): - self.module.fail_json(msg='Gateway "%s" does not exist or does not match target rule ip protocol.' % params['gateway']) - - if params.get('floating') and params.get('direction') == 'any': - self.module.fail_json(msg='Gateways can not be used in Floating rules without choosing a direction') + if params.get("gateway") is not None and params["gateway"] != "default": + if params["ipprotocol"] == "inet46": + self.module.fail_json( + msg='Gateway selection is not valid for "IPV4+IPV6" address family.' + ) + elif ( + self.pfsense.find_gateway_group_elt( + params["gateway"], params["ipprotocol"] + ) + is None + and self.pfsense.find_gateway_elt( + params["gateway"], None, params["ipprotocol"] + ) + is None + ): + self.module.fail_json( + msg='Gateway "%s" does not exist or does not match target rule ip protocol.' + % params["gateway"] + ) + + if params.get("floating") and params.get("direction") == "any": + self.module.fail_json( + msg="Gateways can not be used in Floating rules without choosing a direction" + ) # tracker - if params.get('tracker') is not None and int(params['tracker']) < 0: - self.module.fail_json(msg='tracker {0} must be a positive integer'.format(params['tracker'])) + if params.get("tracker") is not None and int(params["tracker"]) < 0: + self.module.fail_json( + msg="tracker {0} must be a positive integer".format(params["tracker"]) + ) # sched - if params.get('sched') is not None and self.pfsense.find_schedule_elt(params['sched']) is None: - self.module.fail_json(msg='Schedule {0} does not exist'.format(params['sched'])) + if ( + params.get("sched") is not None + and self.pfsense.find_schedule_elt(params["sched"]) is None + ): + self.module.fail_json( + msg="Schedule {0} does not exist".format(params["sched"]) + ) # quick - if params.get('quick') and not params.get('floating'): - self.module.fail_json(msg='quick can only be used on floating rules') + if params.get("quick") and not params.get("floating"): + self.module.fail_json(msg="quick can only be used on floating rules") # ICMP - if params.get('protocol') == 'icmp' and params.get('icmptype') is not None: - both_types = ['any', 'echorep', 'echoreq', 'paramprob', 'redir', 'routeradv', 'routersol', 'timex', 'unreach'] - v4_types = ['althost', 'dataconv', 'inforep', 'inforeq', 'ipv6-here', 'ipv6-where', 'maskrep', 'maskreq', 'mobredir', 'mobregrep', 'mobregreq'] - v4_types += ['photuris', 'skip', 'squench', 'timerep', 'timereq', 'trace'] - v6_types = ['fqdnrep', 'fqdnreq', 'groupqry', 'grouprep', 'groupterm', 'listendone', 'listenrep', 'listqry', 'mtrace', 'mtraceresp', 'neighbradv'] - v6_types += ['neighbrsol', 'niqry', 'nirep', 'routrrenum', 'toobig', 'wrurep', 'wrureq'] - - icmptypes = list(set(map(str.strip, params['icmptype'].split(',')))) + if params.get("protocol") == "icmp" and params.get("icmptype") is not None: + both_types = [ + "any", + "echorep", + "echoreq", + "paramprob", + "redir", + "routeradv", + "routersol", + "timex", + "unreach", + ] + v4_types = [ + "althost", + "dataconv", + "inforep", + "inforeq", + "ipv6-here", + "ipv6-where", + "maskrep", + "maskreq", + "mobredir", + "mobregrep", + "mobregreq", + ] + v4_types += ["photuris", "skip", "squench", "timerep", "timereq", "trace"] + v6_types = [ + "fqdnrep", + "fqdnreq", + "groupqry", + "grouprep", + "groupterm", + "listendone", + "listenrep", + "listqry", + "mtrace", + "mtraceresp", + "neighbradv", + ] + v6_types += [ + "neighbrsol", + "niqry", + "nirep", + "routrrenum", + "toobig", + "wrurep", + "wrureq", + ] + + icmptypes = list(set(map(str.strip, params["icmptype"].split(",")))) icmptypes.sort() - if '' in icmptypes: - icmptypes.remove('') + if "" in icmptypes: + icmptypes.remove("") if len(icmptypes) == 0: - self.module.fail_json(msg='You must specify at least one icmptype or any for all of them') + self.module.fail_json( + msg="You must specify at least one icmptype or any for all of them" + ) invalids = set(icmptypes) - set(v4_types) - set(v6_types) - set(both_types) if len(invalids) > 0: - self.module.fail_json(msg='ICMP types {0} does not exist'.format(','.join(invalids))) + self.module.fail_json( + msg="ICMP types {0} does not exist".format(",".join(invalids)) + ) - if params['ipprotocol'] == 'inet': + if params["ipprotocol"] == "inet": left = set(icmptypes) - set(v4_types) - set(both_types) - elif params['ipprotocol'] == 'inet6': + elif params["ipprotocol"] == "inet6": left = set(icmptypes) - set(v6_types) - set(both_types) - else: # inet46 only allow + else: # inet46 only allow left = set(icmptypes) - set(both_types) if len(left) > 0: - self.module.fail_json(msg='ICMP types {0} are invalid with IP type {1}'.format(','.join(left), params['ipprotocol'])) - params['icmptype'] = ','.join(icmptypes) + self.module.fail_json( + msg="ICMP types {0} are invalid with IP type {1}".format( + ",".join(left), params["ipprotocol"] + ) + ) + params["icmptype"] = ",".join(icmptypes) ############################## # XML processing # def _adjust_separators(self, start_idx, add=True, before=False): - """ update separators position """ - separators_elt = self.root_elt.find('separator') + """update separators position""" + separators_elt = self.root_elt.find("separator") if separators_elt is None: return - separators_elt = separators_elt.find(self.obj['interface']) + separators_elt = separators_elt.find(self.obj["interface"]) if separators_elt is None: return for separator_elt in separators_elt: - row_elt = separator_elt.find('row') + row_elt = separator_elt.find("row") if row_elt is None or row_elt.text is None: continue - if_elt = separator_elt.find('if') - if if_elt is None or if_elt.text != self.obj['interface']: + if_elt = separator_elt.find("if") + if if_elt is None or if_elt.text != self.obj["interface"]: continue - match = re.match(r'fr(\d+)', row_elt.text) + match = re.match(r"fr(\d+)", row_elt.text) if match: idx = int(match.group(1)) if add: if before: if idx > start_idx: - row_elt.text = 'fr' + str(idx + 1) + row_elt.text = "fr" + str(idx + 1) else: if idx >= start_idx: - row_elt.text = 'fr' + str(idx + 1) + row_elt.text = "fr" + str(idx + 1) elif idx > start_idx: - row_elt.text = 'fr' + str(idx - 1) + row_elt.text = "fr" + str(idx - 1) def _check_tracker(self): - """ check the tracking used is unique and change it if required """ + """check the tracking used is unique and change it if required""" if not self.trackers: - trackers = self.root_elt.findall('tracker') + trackers = self.root_elt.findall("tracker") for tracker in trackers: self.trackers.add(tracker.text) start = int(time.time()) - while self.obj['tracker'] in self.trackers: + while self.obj["tracker"] in self.trackers: start = start + 1 - self.obj['tracker'] = str(start) + self.obj["tracker"] = str(start) # keep the tracker for future calls if module is used with aggregate - self.trackers.add(self.obj['tracker']) + self.trackers.add(self.obj["tracker"]) def _copy_and_add_target(self): - """ create the XML target_elt """ - timestamp = '%d' % int(time.time()) - self.obj['id'] = '' - if 'tracker' not in self.obj: - self.obj['tracker'] = timestamp + """create the XML target_elt""" + timestamp = "%d" % int(time.time()) + self.obj["id"] = "" + if "tracker" not in self.obj: + self.obj["tracker"] = timestamp self._check_tracker() - self.obj['created'] = self.obj['updated'] = dict() - self.obj['created']['time'] = self.obj['updated']['time'] = timestamp - self.obj['created']['username'] = self.obj['updated']['username'] = self.pfsense.get_username() + self.obj["created"] = self.obj["updated"] = dict() + self.obj["created"]["time"] = self.obj["updated"]["time"] = timestamp + self.obj["created"]["username"] = self.obj["updated"]["username"] = ( + self.pfsense.get_username() + ) self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self._rule_element_to_dict() + self.diff["after"] = self._rule_element_to_dict() self._insert(self.target_elt) - self.result['added'].append(self.obj) + self.result["added"].append(self.obj) def _copy_and_update_target(self): - """ update the XML target_elt """ - timestamp = '%d' % int(time.time()) + """update the XML target_elt""" + timestamp = "%d" % int(time.time()) before = self._rule_element_to_dict() - if 'tracker' not in self.obj: - self.obj['tracker'] = before['tracker'] - - if 'associated-rule-id' not in self.obj and 'associated-rule-id' in before and before['associated-rule-id'] != '': - self.module.fail_json(msg='Target filter rule is associated with a NAT rule.') - - self.diff['before'] = before + if "tracker" not in self.obj: + self.obj["tracker"] = before["tracker"] + + if ( + "associated-rule-id" not in self.obj + and "associated-rule-id" in before + and before["associated-rule-id"] != "" + ): + self.module.fail_json( + msg="Target filter rule is associated with a NAT rule." + ) + + self.diff["before"] = before changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) if self._remove_deleted_params(): changed = True @@ -341,42 +505,46 @@ def _copy_and_update_target(self): changed = True if changed: - updated_elt = self.target_elt.find('updated') + updated_elt = self.target_elt.find("updated") if updated_elt is None: - updated_elt = self.pfsense.new_element('updated') - updated_elt.append(self.pfsense.new_element('time', timestamp)) - updated_elt.append(self.pfsense.new_element('username', self.pfsense.get_username())) + updated_elt = self.pfsense.new_element("updated") + updated_elt.append(self.pfsense.new_element("time", timestamp)) + updated_elt.append( + self.pfsense.new_element("username", self.pfsense.get_username()) + ) self.target_elt.append(updated_elt) else: - updated_elt.find('time').text = timestamp - updated_elt.find('username').text = self.pfsense.get_username() - self.diff['after'].update(self._rule_element_to_dict()) - self.result['modified'].append(self._rule_element_to_dict()) + updated_elt.find("time").text = timestamp + updated_elt.find("username").text = self.pfsense.get_username() + self.diff["after"].update(self._rule_element_to_dict()) + self.result["modified"].append(self._rule_element_to_dict()) return (before, changed) def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('rule') + """create the XML target_elt""" + return self.pfsense.new_element("rule") def _find_matching_rule(self): - """ return rule element and index that matches by description or action """ + """return rule element and index that matches by description or action""" # Prioritize matching my name - if 'associated-rule-id' in self.obj: - found, i = self._find_rule(self.obj['associated-rule-id'], 'associated-rule-id') + if "associated-rule-id" in self.obj: + found, i = self._find_rule( + self.obj["associated-rule-id"], "associated-rule-id" + ) if found is not None: return (found, i) - found, i = self._find_rule(self.obj['descr']) + found, i = self._find_rule(self.obj["descr"]) if found is not None: return (found, i) # Match action without name/descr match_rule = self.obj.copy() - del match_rule['descr'] + del match_rule["descr"] for rule_elt in self.root_elt: this_rule = self.pfsense.element_to_dict(rule_elt) - this_rule.pop('descr', None) + this_rule.pop("descr", None) # Remove unmanaged elements for unwanted in RULE_UNMANAGED_ELEMENTS: this_rule.pop(unwanted, None) @@ -386,30 +554,44 @@ def _find_matching_rule(self): return (None, -1) - def _find_rule(self, value, field='descr'): - """ return rule element and index on interface/floating that matches criteria """ + def _find_rule(self, value, field="descr"): + """return rule element and index on interface/floating that matches criteria""" i = 0 for rule_elt in self.root_elt: field_elt = rule_elt.find(field) - if self._match_interface(rule_elt) and field_elt is not None and field_elt.text == value: + if ( + self._match_interface(rule_elt) + and field_elt is not None + and field_elt.text == value + ): return (rule_elt, i) i += 1 return (None, -1) def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" rule_elt, dummy = self._find_matching_rule() if rule_elt is not None and self._floating: - ifs_elt = rule_elt.find('interface') - self._floating_interfaces = ','.join([self.pfsense.get_interface_display_name(interface) for interface in ifs_elt.text.split(',')]) + ifs_elt = rule_elt.find("interface") + self._floating_interfaces = ",".join( + [ + self.pfsense.get_interface_display_name(interface) + for interface in ifs_elt.text.split(",") + ] + ) return rule_elt def _get_expected_rule_position(self): - """ get expected rule position in interface/floating """ - if self._before == 'bottom': - return self.pfsense.get_interface_rules_count(self.obj['interface'], self._floating) - 1 - elif self._after == 'top': + """get expected rule position in interface/floating""" + if self._before == "bottom": + return ( + self.pfsense.get_interface_rules_count( + self.obj["interface"], self._floating + ) + - 1 + ) + elif self._after == "top": return 0 elif self._after is not None: return self._get_rule_position(self._after, first=False) + 1 @@ -422,36 +604,44 @@ def _get_expected_rule_position(self): position = self._get_rule_position(self._after, fail=False) if position is not None: return position - return self.pfsense.get_interface_rules_count(self.obj['interface'], self._floating) + return self.pfsense.get_interface_rules_count( + self.obj["interface"], self._floating + ) return -1 def _get_expected_rule_xml_index(self): - """ get expected rule index in xml """ - if self._before == 'bottom': + """get expected rule index in xml""" + if self._before == "bottom": return self._get_last_rule_xml_index() + 1 - elif self._after == 'top': + elif self._after == "top": return self._get_first_rule_xml_index() elif self._after is not None: found, i = self._find_rule(self._after) if found is not None: return i + 1 else: - self.module.fail_json(msg='Failed to insert after rule=%s interface=%s' % (self._after, self._interface_name())) + self.module.fail_json( + msg="Failed to insert after rule=%s interface=%s" + % (self._after, self._interface_name()) + ) elif self._before is not None: found, i = self._find_rule(self._before) if found is not None: return i else: - self.module.fail_json(msg='Failed to insert before rule=%s interface=%s' % (self._before, self._interface_name())) + self.module.fail_json( + msg="Failed to insert before rule=%s interface=%s" + % (self._before, self._interface_name()) + ) else: - found, i = self._find_rule(self.obj['descr']) + found, i = self._find_rule(self.obj["descr"]) if found is not None: return i return self._get_last_rule_xml_index() + 1 return -1 def _get_first_rule_xml_index(self): - """ Find the first rule for the interface/floating and return its xml index """ + """Find the first rule for the interface/floating and return its xml index""" i = 0 for rule_elt in self.root_elt: if self._match_interface(rule_elt): @@ -460,7 +650,7 @@ def _get_first_rule_xml_index(self): return i def _get_last_rule_xml_index(self): - """ Find the last rule for the interface/floating and return its xml index """ + """Find the last rule for the interface/floating and return its xml index""" last_found = -1 i = 0 for rule_elt in self.root_elt: @@ -471,41 +661,63 @@ def _get_last_rule_xml_index(self): @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - return ['log', 'protocol', 'disabled', 'defaultqueue', 'ackqueue', 'dnpipe', 'pdnpipe', 'gateway', 'icmptype', 'sched', 'quick', 'tcpflags_any'] + """returns the list of params to remove if they are not set""" + return [ + "log", + "protocol", + "disabled", + "defaultqueue", + "ackqueue", + "dnpipe", + "pdnpipe", + "gateway", + "icmptype", + "sched", + "quick", + "tcpflags_any", + ] def _get_rule_position(self, descr=None, fail=True, first=True): - """ get rule position in interface/floating """ + """get rule position in interface/floating""" if descr is None: - descr = self.obj['descr'] + descr = self.obj["descr"] - res = self.pfsense.get_rule_position(descr, self.obj['interface'], self._floating, first=first) + res = self.pfsense.get_rule_position( + descr, self.obj["interface"], self._floating, first=first + ) if fail and res is None: - self.module.fail_json(msg='Failed to find rule=%s interface=%s' % (descr, self._interface_name())) + self.module.fail_json( + msg="Failed to find rule=%s interface=%s" + % (descr, self._interface_name()) + ) return res def _insert(self, rule_elt): - """ insert rule into xml """ + """insert rule into xml""" rule_xml_idx = self._get_expected_rule_xml_index() self.root_elt.insert(rule_xml_idx, rule_elt) rule_position = self._get_rule_position() - self._adjust_separators(rule_position, before=(self._after is None and self._before is not None)) + self._adjust_separators( + rule_position, before=(self._after is None and self._before is not None) + ) def _match_interface(self, rule_elt): - """ check if a rule elt match the targeted interface """ - return self.pfsense.rule_match_interface(rule_elt, self.obj['interface'], self._floating) + """check if a rule elt match the targeted interface""" + return self.pfsense.rule_match_interface( + rule_elt, self.obj["interface"], self._floating + ) def _update_rule_position(self, rule_elt): - """ move rule in xml if required """ + """move rule in xml if required""" current_position = self._get_rule_position() expected_position = self._get_expected_rule_position() if current_position == expected_position: self._position_changed = False return False - self.diff['before']['position'] = current_position - self.diff['after']['position'] = expected_position + self.diff["before"]["position"] = current_position + self.diff["after"]["position"] = expected_position self._adjust_separators(current_position, add=False) self.root_elt.remove(rule_elt) self._insert(rule_elt) @@ -516,144 +728,229 @@ def _update_rule_position(self, rule_elt): # run # def _pre_remove_target_elt(self): - """ processing before removing elt """ + """processing before removing elt""" self._adjust_separators(self._get_rule_position(), add=False) - self.diff['before'] = self._rule_element_to_dict() - self.result['deleted'].append(self._rule_element_to_dict()) + self.diff["before"] = self._rule_element_to_dict() + self.result["deleted"].append(self._rule_element_to_dict()) def _rule_element_to_dict(self): - """ convert rule_elt to dictionary like module arguments """ + """convert rule_elt to dictionary like module arguments""" rule = self.pfsense.element_to_dict(self.target_elt) # We use 'name' for 'descr' - rule['name'] = rule.pop('descr', 'UNKNOWN') + rule["name"] = rule.pop("descr", "UNKNOWN") # We use 'action' for 'type' - rule['action'] = rule.pop('type', 'UNKNOWN') + rule["action"] = rule.pop("type", "UNKNOWN") # Convert addresses to argument format - for addr_item in ['source', 'destination']: - rule[addr_item], rule[addr_item + '_port'] = self.pfsense.addr_normalize(rule[addr_item]) + for addr_item in ["source", "destination"]: + rule[addr_item], rule[addr_item + "_port"] = self.pfsense.addr_normalize( + rule[addr_item] + ) return rule def _update(self): - """ make the target pfsense reload rules """ - return self.pfsense.phpshell('''require_once("filter.inc"); -if (filter_configure() == 0) { clear_subsystem_dirty('filter'); }''') + """make the target pfsense reload rules""" + return self.pfsense.phpshell( + """require_once("filter.inc"); +if (filter_configure() == 0) { clear_subsystem_dirty('filter'); }""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}' on '{1}'".format(self.obj['descr'], self._interface_name()) + """return obj's name""" + return "'{0}' on '{1}'".format(self.obj["descr"], self._interface_name()) def _interface_name(self): - """ return formated interface name for logging """ + """return formated interface name for logging""" if self._floating: if self._floating_interfaces is not None: - return 'floating(' + self._floating_interfaces + ')' - return 'floating(' + self.params['interface'] + ')' - return self.params['interface'] + return "floating(" + self._floating_interfaces + ")" + return "floating(" + self.params["interface"] + ")" + return self.params["interface"] def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'source') - values += self.format_cli_field(self.params, 'source_port') - values += self.format_cli_field(self.params, 'destination') - values += self.format_cli_field(self.params, 'destination_port') - values += self.format_cli_field(self.params, 'protocol', default='any') - values += self.format_cli_field(self.params, 'direction') - values += self.format_cli_field(self.params, 'ipprotocol', default='inet') - values += self.format_cli_field(self.params, 'icmptype', default='any') - values += self.format_cli_field(self.params, 'tcpflags_any', fvalue=self.fvalue_bool) - values += self.format_cli_field(self.params, 'statetype', default='keep state') - values += self.format_cli_field(self.params, 'action', default='pass') - values += self.format_cli_field(self.params, 'disabled', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'log', fvalue=self.fvalue_bool, default=False) - values += self.format_cli_field(self.params, 'after') - values += self.format_cli_field(self.params, 'before') - values += self.format_cli_field(self.params, 'queue') - values += self.format_cli_field(self.params, 'ackqueue') - values += self.format_cli_field(self.params, 'in_queue') - values += self.format_cli_field(self.params, 'out_queue') - values += self.format_cli_field(self.params, 'gateway', default='default') - values += self.format_cli_field(self.params, 'tracker') - values += self.format_cli_field(self.params, 'sched') - values += self.format_cli_field(self.params, 'quick', fvalue=self.fvalue_bool, default=False) + values += self.format_cli_field(self.params, "source") + values += self.format_cli_field(self.params, "source_port") + values += self.format_cli_field(self.params, "destination") + values += self.format_cli_field(self.params, "destination_port") + values += self.format_cli_field(self.params, "protocol", default="any") + values += self.format_cli_field(self.params, "direction") + values += self.format_cli_field(self.params, "ipprotocol", default="inet") + values += self.format_cli_field(self.params, "icmptype", default="any") + values += self.format_cli_field( + self.params, "tcpflags_any", fvalue=self.fvalue_bool + ) + values += self.format_cli_field( + self.params, "statetype", default="keep state" + ) + values += self.format_cli_field(self.params, "action", default="pass") + values += self.format_cli_field( + self.params, "disabled", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field( + self.params, "log", fvalue=self.fvalue_bool, default=False + ) + values += self.format_cli_field(self.params, "after") + values += self.format_cli_field(self.params, "before") + values += self.format_cli_field(self.params, "queue") + values += self.format_cli_field(self.params, "ackqueue") + values += self.format_cli_field(self.params, "in_queue") + values += self.format_cli_field(self.params, "out_queue") + values += self.format_cli_field(self.params, "gateway", default="default") + values += self.format_cli_field(self.params, "tracker") + values += self.format_cli_field(self.params, "sched") + values += self.format_cli_field( + self.params, "quick", fvalue=self.fvalue_bool, default=False + ) else: fbefore = self._obj_to_log_fields(before) fafter = self._obj_to_log_fields(self.obj) - fafter['before'] = self._before - fafter['after'] = self._after - - values += self.format_updated_cli_field(fafter, fbefore, 'source', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'source_port', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'destination', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'destination_port', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'protocol', none_value="'any'", add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'icmptype', add_comma=(values)) - values += self.format_updated_cli_field(fafter, fbefore, 'interface', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'floating', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'direction', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'ipprotocol', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'tcpflags_any', fvalue=self.fvalue_bool, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'statetype', add_comma=(values)) - values += self.format_updated_cli_field(self.params, before, 'action', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'disabled', fvalue=self.fvalue_bool, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'log', fvalue=self.fvalue_bool, add_comma=(values)) + fafter["before"] = self._before + fafter["after"] = self._after + + values += self.format_updated_cli_field( + fafter, fbefore, "source", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "source_port", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "destination", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "destination_port", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "protocol", none_value="'any'", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "icmptype", add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, fbefore, "interface", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "floating", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "direction", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "ipprotocol", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "tcpflags_any", + fvalue=self.fvalue_bool, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, before, "statetype", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.params, before, "action", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "disabled", + fvalue=self.fvalue_bool, + add_comma=(values), + ) + values += self.format_updated_cli_field( + self.obj, before, "log", fvalue=self.fvalue_bool, add_comma=(values) + ) if self._position_changed: - values += self.format_updated_cli_field(fafter, {}, 'after', log_none=False, add_comma=(values)) - values += self.format_updated_cli_field(fafter, {}, 'before', log_none=False, add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'defaultqueue', fname='queue', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'ackqueue', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'dnpipe', fname='in_queue', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'pdnpipe', fname='out_queue', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'gateway', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'tracker', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'sched', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'quick', fvalue=self.fvalue_bool, add_comma=(values)) + values += self.format_updated_cli_field( + fafter, {}, "after", log_none=False, add_comma=(values) + ) + values += self.format_updated_cli_field( + fafter, {}, "before", log_none=False, add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "defaultqueue", fname="queue", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "ackqueue", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "dnpipe", fname="in_queue", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "pdnpipe", fname="out_queue", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "gateway", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "tracker", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "sched", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "quick", fvalue=self.fvalue_bool, add_comma=(values) + ) return values def _obj_address_to_log_field(self, rule, addr): - """ return formated address from dict """ - field = '' - field_port = '' + """return formated address from dict""" + field = "" + field_port = "" if isinstance(rule[addr], dict): - if 'not' in rule[addr]: - field += '!' - if 'any' in rule[addr]: - field += 'any' - if 'address' in rule[addr]: - field += rule[addr]['address'] - elif 'network' in rule[addr]: + if "not" in rule[addr]: + field += "!" + if "any" in rule[addr]: + field += "any" + if "address" in rule[addr]: + field += rule[addr]["address"] + elif "network" in rule[addr]: interface = None - if rule[addr]['network'].endswith('ip'): - interface = self.pfsense.get_interface_display_name(rule[addr]['network'][:-2], return_none=True) + if rule[addr]["network"].endswith("ip"): + interface = self.pfsense.get_interface_display_name( + rule[addr]["network"][:-2], return_none=True + ) if interface is None: - field += 'NET:' + self.pfsense.get_interface_display_name(rule[addr]['network']) + field += "NET:" + self.pfsense.get_interface_display_name( + rule[addr]["network"] + ) else: - field += 'IP:' + interface + field += "IP:" + interface - if 'port' in rule[addr]: - field_port += rule[addr]['port'] + if "port" in rule[addr]: + field_port += rule[addr]["port"] else: - if rule[addr].startswith('NET:'): - field = 'NET:' + self.pfsense.get_interface_display_name(rule[addr][4:]) - elif rule[addr].startswith('IP:'): - field = 'IP:' + self.pfsense.get_interface_display_name(rule[addr][3:]) + if rule[addr].startswith("NET:"): + field = "NET:" + self.pfsense.get_interface_display_name(rule[addr][4:]) + elif rule[addr].startswith("IP:"): + field = "IP:" + self.pfsense.get_interface_display_name(rule[addr][3:]) else: field = rule[addr] - field_port = rule[addr + '_port'] + field_port = rule[addr + "_port"] return field, field_port def _obj_to_log_fields(self, rule): - """ return formated source and destination from dict """ + """return formated source and destination from dict""" res = {} - res['source'], res['source_port'] = self._obj_address_to_log_field(rule, 'source') - res['destination'], res['destination_port'] = self._obj_address_to_log_field(rule, 'destination') - res['interface'] = ','.join([self.pfsense.get_interface_display_name(interface) for interface in rule['interface'].split(',')]) + res["source"], res["source_port"] = self._obj_address_to_log_field( + rule, "source" + ) + res["destination"], res["destination_port"] = self._obj_address_to_log_field( + rule, "destination" + ) + res["interface"] = ",".join( + [ + self.pfsense.get_interface_display_name(interface) + for interface in rule["interface"].split(",") + ] + ) return res diff --git a/plugins/module_utils/rule_separator.py b/plugins/module_utils/rule_separator.py index 4696a7c7..a5cb5443 100644 --- a/plugins/module_utils/rule_separator.py +++ b/plugins/module_utils/rule_separator.py @@ -5,29 +5,34 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) RULE_SEPARATOR_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - state=dict(default='present', choices=['present', 'absent']), - interface=dict(required=False, type='str'), - floating=dict(required=False, type='bool'), - color=dict(default='info', required=False, choices=['info', 'warning', 'danger', 'success']), - after=dict(default=None, required=False, type='str'), - before=dict(default=None, required=False, type='str'), + name=dict(required=True, type="str"), + state=dict(default="present", choices=["present", "absent"]), + interface=dict(required=False, type="str"), + floating=dict(required=False, type="bool"), + color=dict( + default="info", required=False, choices=["info", "warning", "danger", "success"] + ), + after=dict(default=None, required=False, type="str"), + before=dict(default=None, required=False, type="str"), ) -RULE_SEPARATOR_REQUIRED_ONE_OF = [['interface', 'floating']] -RULE_SEPARATOR_MUTUALLY_EXCLUSIVE = [['interface', 'floating']] +RULE_SEPARATOR_REQUIRED_ONE_OF = [["interface", "floating"]] +RULE_SEPARATOR_MUTUALLY_EXCLUSIVE = [["interface", "floating"]] class PFSenseRuleSeparatorModule(PFSenseModuleBase): - """ module managing pfsense rule separators """ + """module managing pfsense rule separators""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return RULE_SEPARATOR_ARGUMENT_SPEC ############################## @@ -41,9 +46,9 @@ def __init__(self, module, pfsense=None): self.root_elt = None self.obj = dict() - self.separators = self.pfsense.rules.find('separator') + self.separators = self.pfsense.rules.find("separator") if self.separators is None: - self.separators = self.pfsense.new_element('separator') + self.separators = self.pfsense.new_element("separator") self.pfsense.rules.append(self.separators) self._interface_name = None @@ -55,30 +60,30 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return an separator dict from module params """ + """return an separator dict from module params""" params = self.params - self._floating = (params.get('floating')) - self._after = params.get('after') - self._before = params.get('before') + self._floating = params.get("floating") + self._after = params.get("after") + self._before = params.get("before") obj = dict() self.obj = obj - obj['text'] = params['name'] - if params.get('floating'): - self._interface_name = 'floating' - obj['if'] = 'floatingrules' + obj["text"] = params["name"] + if params.get("floating"): + self._interface_name = "floating" + obj["if"] = "floatingrules" else: - self._interface_name = params['interface'].lower() - obj['if'] = self.pfsense.parse_interface(params['interface']).lower() + self._interface_name = params["interface"].lower() + obj["if"] = self.pfsense.parse_interface(params["interface"]).lower() - if params['state'] == 'present': - obj['color'] = 'bg-' + params['color'] - obj['row'] = 'fr' + str(self._get_expected_separator_position()) + if params["state"] == "present": + obj["color"] = "bg-" + params["color"] + obj["row"] = "fr" + str(self._get_expected_separator_position()) - self.root_elt = self.separators.find(obj['if']) + self.root_elt = self.separators.find(obj["if"]) if self.root_elt is None: - self.root_elt = self.pfsense.new_element(obj['if']) + self.root_elt = self.pfsense.new_element(obj["if"]) self.separators.append(self.root_elt) return obj @@ -87,29 +92,31 @@ def _params_to_obj(self): # XML processing # def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('sep') + """create the XML target_elt""" + return self.pfsense.new_element("sep") def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) self.root_elt.append(self.target_elt) self._recompute_separators_tag() def _find_target(self): - """ find the XML target_elt """ - if_elt = self.separators.find(self.obj['if']) + """find the XML target_elt""" + if_elt = self.separators.find(self.obj["if"]) if if_elt is not None: for separator_elt in if_elt: - if separator_elt.find('text').text == self.obj['text']: + if separator_elt.find("text").text == self.obj["text"]: return separator_elt return None def _get_expected_separator_position(self): - """ get expected separator position in interface/floating """ - if self._before == 'bottom': - return self.pfsense.get_interface_rules_count(self.obj['if'], self._floating) - elif self._after == 'top': + """get expected separator position in interface/floating""" + if self._before == "bottom": + return self.pfsense.get_interface_rules_count( + self.obj["if"], self._floating + ) + elif self._after == "top": return 0 elif self._after is not None: return self._get_rule_position(self._after) + 1 @@ -119,34 +126,39 @@ def _get_expected_separator_position(self): position = self._get_separator_position() if position is not None: return position - return self.pfsense.get_interface_rules_count(self.obj['if'], self._floating) + return self.pfsense.get_interface_rules_count( + self.obj["if"], self._floating + ) return -1 def _get_rule_position(self, descr): - """ get rule position in interface/floating """ - res = self.pfsense.get_rule_position(descr, self.obj['if'], self._floating) + """get rule position in interface/floating""" + res = self.pfsense.get_rule_position(descr, self.obj["if"], self._floating) if res is None: - self.module.fail_json(msg='Failed to find rule=%s interface=%s' % (descr, self._interface_name)) + self.module.fail_json( + msg="Failed to find rule=%s interface=%s" + % (descr, self._interface_name) + ) return res def _get_separator_position(self): - """ get separator position in interface/floating """ + """get separator position in interface/floating""" separator_elt = self._find_target() if separator_elt is not None: - return int(separator_elt.find('row').text.replace('fr', '')) + return int(separator_elt.find("row").text.replace("fr", "")) return None def _post_remove_target_elt(self): - """ processing after removing elt """ + """processing after removing elt""" self._recompute_separators_tag() def _recompute_separators_tag(self): - """ recompute separators tag name """ - if_elt = self.separators.find(self.obj['if']) + """recompute separators tag name""" + if_elt = self.separators.find(self.obj["if"]) if if_elt is not None: i = 0 for separator_elt in if_elt: - name = 'sep' + str(i) + name = "sep" + str(i) if separator_elt.tag != name: separator_elt.tag = name i += 1 @@ -155,26 +167,28 @@ def _recompute_separators_tag(self): # run # def _update(self): - """ make the target pfsense reload separators """ - return self.pfsense.phpshell('''require_once("filter.inc"); -if (filter_configure() == 0) { clear_subsystem_dirty('filter'); }''') + """make the target pfsense reload separators""" + return self.pfsense.phpshell( + """require_once("filter.inc"); +if (filter_configure() == 0) { clear_subsystem_dirty('filter'); }""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}' on '{1}'".format(self.obj['text'], self._interface_name) + """return obj's name""" + return "'{0}' on '{1}'".format(self.obj["text"], self._interface_name) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'color') - values += self.format_cli_field(self.params, 'after') - values += self.format_cli_field(self.params, 'before') + values += self.format_cli_field(self.params, "color") + values += self.format_cli_field(self.params, "after") + values += self.format_cli_field(self.params, "before") else: - values += self.format_cli_field(self.params, 'color', add_comma=(values)) - values += self.format_cli_field(self.params, 'after', add_comma=(values)) - values += self.format_cli_field(self.params, 'before', add_comma=(values)) + values += self.format_cli_field(self.params, "color", add_comma=(values)) + values += self.format_cli_field(self.params, "after", add_comma=(values)) + values += self.format_cli_field(self.params, "before", add_comma=(values)) return values diff --git a/plugins/module_utils/vlan.py b/plugins/module_utils/vlan.py index 7464ccc5..48f7eab4 100644 --- a/plugins/module_utils/vlan.py +++ b/plugins/module_utils/vlan.py @@ -4,24 +4,27 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) VLAN_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - interface=dict(required=True, type='str'), - vlan_id=dict(required=True, type='int'), - priority=dict(default=None, required=False, type='int'), - descr=dict(default='', type='str'), + state=dict(default="present", choices=["present", "absent"]), + interface=dict(required=True, type="str"), + vlan_id=dict(required=True, type="int"), + priority=dict(default=None, required=False, type="int"), + descr=dict(default="", type="str"), ) class PFSenseVlanModule(PFSenseModuleBase): - """ module managing pfsense vlans """ + """module managing pfsense vlans""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return VLAN_ARGUMENT_SPEC ############################## @@ -32,11 +35,11 @@ def __init__(self, module, pfsense=None): self.name = "pfsense_vlan" # Override for use with aggregate self.argument_spec = VLAN_ARGUMENT_SPEC - self.root_elt = self.pfsense.get_element('vlans') + self.root_elt = self.pfsense.get_element("vlans") self.obj = dict() if self.root_elt is None: - self.root_elt = self.pfsense.new_element('vlans') + self.root_elt = self.pfsense.new_element("vlans") self.pfsense.root.append(self.root_elt) self.setup_vlan_cmds = "" @@ -44,33 +47,36 @@ def __init__(self, module, pfsense=None): # get physical interfaces on which vlans can be set get_interface_cmd = ( 'require_once("/etc/inc/interfaces.inc");' - '$portlist = get_interface_list();' - '$lagglist = get_lagg_interface_list();' - '$portlist = array_merge($portlist, $lagglist);' - 'foreach ($lagglist as $laggif => $lagg) {' + "$portlist = get_interface_list();" + "$lagglist = get_lagg_interface_list();" + "$portlist = array_merge($portlist, $lagglist);" + "foreach ($lagglist as $laggif => $lagg) {" " $laggmembers = explode(',', $lagg['members']);" - ' foreach ($laggmembers as $lagm)' - ' if (isset($portlist[$lagm]))' - ' unset($portlist[$lagm]);' - '}') + " foreach ($laggmembers as $lagm)" + " if (isset($portlist[$lagm]))" + " unset($portlist[$lagm]);" + "}" + ) if self.pfsense.is_at_least_2_5_0(): get_interface_cmd += ( - '$list = array();' - 'foreach ($portlist as $ifn => $ifinfo) {' + "$list = array();" + "foreach ($portlist as $ifn => $ifinfo) {" ' $list[$ifn] = $ifn . " (" . $ifinfo["mac"] . ")";' - ' $iface = convert_real_interface_to_friendly_interface_name($ifn);' - ' if (isset($iface) && strlen($iface) > 0)' + " $iface = convert_real_interface_to_friendly_interface_name($ifn);" + " if (isset($iface) && strlen($iface) > 0)" ' $list[$ifn] .= " - $iface";' - '}' - 'echo json_encode($list);') + "}" + "echo json_encode($list);" + ) else: get_interface_cmd += ( - '$list = array();' - 'foreach ($portlist as $ifn => $ifinfo)' - ' if (is_jumbo_capable($ifn))' - ' array_push($list, $ifn);' - 'echo json_encode($list);') + "$list = array();" + "foreach ($portlist as $ifn => $ifinfo)" + " if (is_jumbo_capable($ifn))" + " array_push($list, $ifn);" + "echo json_encode($list);" + ) self.interfaces = self.pfsense.php(get_interface_cmd) @@ -78,112 +84,146 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() - obj['tag'] = str(params['vlan_id']) - if params['interface'] not in self.interfaces: - obj['if'] = self.pfsense.get_interface_port_by_display_name(params['interface']) - if obj['if'] is None: - obj['if'] = self.pfsense.get_interface_port(params['interface']) + obj["tag"] = str(params["vlan_id"]) + if params["interface"] not in self.interfaces: + obj["if"] = self.pfsense.get_interface_port_by_display_name( + params["interface"] + ) + if obj["if"] is None: + obj["if"] = self.pfsense.get_interface_port(params["interface"]) else: - obj['if'] = params['interface'] + obj["if"] = params["interface"] - if params['state'] == 'present': - if params['priority'] is not None: - obj['pcp'] = str(params['priority']) + if params["state"] == "present": + if params["priority"] is not None: + obj["pcp"] = str(params["priority"]) else: - obj['pcp'] = '' + obj["pcp"] = "" - obj['descr'] = params['descr'] - obj['vlanif'] = '{0}.{1}'.format(obj['if'], obj['tag']) + obj["descr"] = params["descr"] + obj["vlanif"] = "{0}.{1}".format(obj["if"], obj["tag"]) return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params # check interface - if params['interface'] not in self.interfaces: + if params["interface"] not in self.interfaces: # check with assign or friendly name - interface = self.pfsense.get_interface_port_by_display_name(params['interface']) + interface = self.pfsense.get_interface_port_by_display_name( + params["interface"] + ) if interface is None: - interface = self.pfsense.get_interface_port(params['interface']) + interface = self.pfsense.get_interface_port(params["interface"]) if interface is None or interface not in self.interfaces: - self.module.fail_json(msg='Vlans can\'t be set on interface {0}'.format(params['interface'])) + self.module.fail_json( + msg="Vlans can't be set on interface {0}".format( + params["interface"] + ) + ) # check vlan - if params['vlan_id'] < 1 or params['vlan_id'] > 4094: - self.module.fail_json(msg='vlan_id must be between 1 and 4094 on interface {0}'.format(params['interface'])) + if params["vlan_id"] < 1 or params["vlan_id"] > 4094: + self.module.fail_json( + msg="vlan_id must be between 1 and 4094 on interface {0}".format( + params["interface"] + ) + ) # check priority - if params.get('priority') is not None and (params['priority'] < 0 or params['priority'] > 7): - self.module.fail_json(msg='priority must be between 0 and 7 on interface {0}'.format(params['interface'])) + if params.get("priority") is not None and ( + params["priority"] < 0 or params["priority"] > 7 + ): + self.module.fail_json( + msg="priority must be between 0 and 7 on interface {0}".format( + params["interface"] + ) + ) ############################## # XML processing # def _cmd_create(self): - """ return the php shell to create the vlan's interface """ + """return the php shell to create the vlan's interface""" cmd = "$vlan = array();\n" - cmd += "$vlan['if'] = '{0}';\n".format(self.obj['if']) - cmd += "$vlan['tag'] = '{0}';\n".format(self.obj['tag']) - cmd += "$vlan['pcp'] = '{0}';\n".format(self.obj['pcp']) - cmd += "$vlan['descr'] = '{0}';\n".format(self.obj['descr']) - cmd += "$vlan['vlanif'] = '{0}';\n".format(self.obj['vlanif']) + cmd += "$vlan['if'] = '{0}';\n".format(self.obj["if"]) + cmd += "$vlan['tag'] = '{0}';\n".format(self.obj["tag"]) + cmd += "$vlan['pcp'] = '{0}';\n".format(self.obj["pcp"]) + cmd += "$vlan['descr'] = '{0}';\n".format(self.obj["descr"]) + cmd += "$vlan['vlanif'] = '{0}';\n".format(self.obj["vlanif"]) cmd += "$vlanif = interface_vlan_configure($vlan);\n" - cmd += "if ($vlanif == NULL || $vlanif != $vlan['vlanif']) {pfSense_interface_destroy('%s');} else {\n" % (self.obj['vlanif']) + cmd += ( + "if ($vlanif == NULL || $vlanif != $vlan['vlanif']) {pfSense_interface_destroy('%s');} else {\n" + % (self.obj["vlanif"]) + ) # if vlan is assigned to an interface, configuration needs to be applied again - interface = self.pfsense.get_interface_by_port('{0}.{1}'.format(self.obj['if'], self.obj['tag'])) + interface = self.pfsense.get_interface_by_port( + "{0}.{1}".format(self.obj["if"], self.obj["tag"]) + ) if interface is not None: cmd += "interface_configure('{0}', true);\n".format(interface) - cmd += '}\n' + cmd += "}\n" return cmd def _copy_and_add_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" super(PFSenseVlanModule, self)._copy_and_add_target() self.setup_vlan_cmds += self._cmd_create() def _copy_and_update_target(self): - """ update the XML target_elt """ - old_vlanif = self.target_elt.find('vlanif').text + """update the XML target_elt""" + old_vlanif = self.target_elt.find("vlanif").text (before, changed) = super(PFSenseVlanModule, self)._copy_and_update_target() if changed: - self.setup_vlan_cmds += "pfSense_interface_destroy('{0}');\n".format(old_vlanif) + self.setup_vlan_cmds += "pfSense_interface_destroy('{0}');\n".format( + old_vlanif + ) self.setup_vlan_cmds += self._cmd_create() return (before, changed) def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('vlan') + """create the XML target_elt""" + return self.pfsense.new_element("vlan") def _find_target(self): - """ find the XML target_elt """ - return self.pfsense.find_vlan(self.obj['if'], self.obj['tag']) + """find the XML target_elt""" + return self.pfsense.find_vlan(self.obj["if"], self.obj["tag"]) def _pre_remove_target_elt(self): - """ processing before removing elt """ - if self.pfsense.get_interface_by_port('{0}.{1}'.format(self.obj['if'], self.obj['tag'])) is not None: + """processing before removing elt""" + if ( + self.pfsense.get_interface_by_port( + "{0}.{1}".format(self.obj["if"], self.obj["tag"]) + ) + is not None + ): self.module.fail_json( - msg='vlan {0} on {1} cannot be deleted because it is still being used as an interface'.format(self.obj['tag'], self.obj['if']) + msg="vlan {0} on {1} cannot be deleted because it is still being used as an interface".format( + self.obj["tag"], self.obj["if"] + ) ) - self.setup_vlan_cmds += "pfSense_interface_destroy('{0}');\n".format(self.target_elt.find('vlanif').text) + self.setup_vlan_cmds += "pfSense_interface_destroy('{0}');\n".format( + self.target_elt.find("vlanif").text + ) ############################## # run # def get_update_cmds(self): - """ build and return php commands to setup interfaces """ + """build and return php commands to setup interfaces""" cmd = 'require_once("filter.inc");\n' if self.setup_vlan_cmds != "": cmd += 'require_once("interfaces.inc");\n' @@ -192,23 +232,27 @@ def get_update_cmds(self): return cmd def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" return self.pfsense.phpshell(self.get_update_cmds()) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}.{1}'".format(self.obj['if'], self.obj['tag']) + """return obj's name""" + return "'{0}.{1}'".format(self.obj["if"], self.obj["tag"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.obj, 'descr') - values += self.format_cli_field(self.obj, 'pcp', fname='priority') + values += self.format_cli_field(self.obj, "descr") + values += self.format_cli_field(self.obj, "pcp", fname="priority") else: - values += self.format_updated_cli_field(self.obj, before, 'pcp', add_comma=(values), fname='priority') - values += self.format_updated_cli_field(self.obj, before, 'descr', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "pcp", add_comma=(values), fname="priority" + ) + values += self.format_updated_cli_field( + self.obj, before, "descr", add_comma=(values) + ) return values diff --git a/plugins/modules/pfsense_aggregate.py b/plugins/modules/pfsense_aggregate.py index ce43e4cc..4a53c097 100644 --- a/plugins/modules/pfsense_aggregate.py +++ b/plugins/modules/pfsense_aggregate.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -624,34 +627,52 @@ sample: ["create vlan 'mvneta.100', descr='voice', priority='5'", "update vlan 'mvneta.100', set priority='6'", "delete vlan 'mvneta.100'"] """ -from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import PFSenseModule -from ansible_collections.pfsensible.core.plugins.module_utils.alias import PFSenseAliasModule, ALIAS_ARGUMENT_SPEC, ALIAS_MUTUALLY_EXCLUSIVE, ALIAS_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import ( + PFSenseModule, +) +from ansible_collections.pfsensible.core.plugins.module_utils.alias import ( + PFSenseAliasModule, + ALIAS_ARGUMENT_SPEC, + ALIAS_MUTUALLY_EXCLUSIVE, + ALIAS_REQUIRED_IF, +) from ansible_collections.pfsensible.core.plugins.module_utils.interface import ( PFSenseInterfaceModule, INTERFACE_ARGUMENT_SPEC, INTERFACE_REQUIRED_IF, INTERFACE_MUTUALLY_EXCLUSIVE, ) -from ansible_collections.pfsensible.core.plugins.module_utils.nat_outbound import PFSenseNatOutboundModule, NAT_OUTBOUND_ARGUMENT_SPEC, NAT_OUTBOUND_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.nat_outbound import ( + PFSenseNatOutboundModule, + NAT_OUTBOUND_ARGUMENT_SPEC, + NAT_OUTBOUND_REQUIRED_IF, +) from ansible_collections.pfsensible.core.plugins.module_utils.nat_port_forward import ( PFSenseNatPortForwardModule, NAT_PORT_FORWARD_ARGUMENT_SPEC, - NAT_PORT_FORWARD_REQUIRED_IF + NAT_PORT_FORWARD_REQUIRED_IF, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule import ( + PFSenseRuleModule, + RULE_ARGUMENT_SPEC, + RULE_REQUIRED_IF, ) -from ansible_collections.pfsensible.core.plugins.module_utils.rule import PFSenseRuleModule, RULE_ARGUMENT_SPEC, RULE_REQUIRED_IF from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import ( PFSenseRuleSeparatorModule, RULE_SEPARATOR_ARGUMENT_SPEC, RULE_SEPARATOR_REQUIRED_ONE_OF, RULE_SEPARATOR_MUTUALLY_EXCLUSIVE, ) -from ansible_collections.pfsensible.core.plugins.module_utils.vlan import PFSenseVlanModule, VLAN_ARGUMENT_SPEC +from ansible_collections.pfsensible.core.plugins.module_utils.vlan import ( + PFSenseVlanModule, + VLAN_ARGUMENT_SPEC, +) from ansible.module_utils.basic import AnsibleModule class PFSenseModuleAggregate(object): - """ module managing pfsense aggregated aliases, rules, rule separators, interfaces and VLANs """ + """module managing pfsense aggregated aliases, rules, rule separators, interfaces and VLANs""" def __init__(self, module): self.module = module @@ -659,7 +680,9 @@ def __init__(self, module): self.pfsense_aliases = PFSenseAliasModule(module, self.pfsense) self.pfsense_interfaces = PFSenseInterfaceModule(module, self.pfsense) self.pfsense_nat_outbounds = PFSenseNatOutboundModule(module, self.pfsense) - self.pfsense_nat_port_forwards = PFSenseNatPortForwardModule(module, self.pfsense) + self.pfsense_nat_port_forwards = PFSenseNatPortForwardModule( + module, self.pfsense + ) self.pfsense_rules = PFSenseRuleModule(module, self.pfsense) self.pfsense_rule_separators = PFSenseRuleSeparatorModule(module, self.pfsense) self.pfsense_vlans = PFSenseVlanModule(module, self.pfsense) @@ -669,54 +692,61 @@ def _update(self): run = False cmd = 'require_once("filter.inc");\n' # TODO: manage one global list of commands as ordering can be important between modules - if self.pfsense_vlans.result['changed']: + if self.pfsense_vlans.result["changed"]: run = True cmd += self.pfsense_vlans.get_update_cmds() - if self.pfsense_interfaces.result['changed']: + if self.pfsense_interfaces.result["changed"]: run = True cmd += self.pfsense_interfaces.get_update_cmds() - cmd += 'if (filter_configure() == 0) { \n' - if self.pfsense_aliases.result['changed']: + cmd += "if (filter_configure() == 0) { \n" + if self.pfsense_aliases.result["changed"]: run = True - cmd += 'clear_subsystem_dirty(\'aliases\');\n' + cmd += "clear_subsystem_dirty('aliases');\n" - if self.pfsense_nat_port_forwards.result['changed'] or self.pfsense_nat_outbounds.result['changed']: + if ( + self.pfsense_nat_port_forwards.result["changed"] + or self.pfsense_nat_outbounds.result["changed"] + ): run = True - cmd += 'clear_subsystem_dirty(\'natconf\');\n' - - if (self.pfsense_rules.result['changed'] or self.pfsense_rule_separators.result['changed'] or - self.pfsense_nat_port_forwards.result['changed'] or self.pfsense_nat_outbounds.result['changed']): + cmd += "clear_subsystem_dirty('natconf');\n" + + if ( + self.pfsense_rules.result["changed"] + or self.pfsense_rule_separators.result["changed"] + or self.pfsense_nat_port_forwards.result["changed"] + or self.pfsense_nat_outbounds.result["changed"] + ): run = True - cmd += 'clear_subsystem_dirty(\'filter\');\n' - cmd += '}' + cmd += "clear_subsystem_dirty('filter');\n" + cmd += "}" if run: return self.pfsense.phpshell(cmd) - return ('', '', '') + return ("", "", "") def _parse_floating_interfaces(self, interfaces): - """ parse interfaces """ + """parse interfaces""" res = set() - for interface in interfaces.split(','): + for interface in interfaces.split(","): res.add(self.pfsense.parse_interface(interface)) return res - def want_rule(self, rule_elt, rules, name_field='name'): - """ return True if we want to keep rule_elt """ - descr = rule_elt.find('descr') - interface = rule_elt.find('interface') - floating = rule_elt.find('floating') is not None + def want_rule(self, rule_elt, rules, name_field="name"): + """return True if we want to keep rule_elt""" + descr = rule_elt.find("descr") + interface = rule_elt.find("interface") + floating = rule_elt.find("floating") is not None # probably not a rule if descr is None or interface is None: return True - if descr.text in self.module.params['ignored_rules']: + if descr.text in self.module.params["ignored_rules"]: return True - key = '{0}_{1}'.format(interface.text, floating) + key = "{0}_{1}".format(interface.text, floating) if key not in self.defined_rules: defined_rules = set() self.defined_rules[key] = defined_rules @@ -727,84 +757,89 @@ def want_rule(self, rule_elt, rules, name_field='name'): return False for rule in rules: - if rule['state'] == 'absent': + if rule["state"] == "absent": continue if rule[name_field] != descr.text: continue - rule_floating = (rule.get('floating') is not None and - (isinstance(rule['floating'], bool) and - rule['floating'] or rule['floating'].lower() in ['yes', 'true'])) + rule_floating = rule.get("floating") is not None and ( + isinstance(rule["floating"], bool) + and rule["floating"] + or rule["floating"].lower() in ["yes", "true"] + ) if floating != rule_floating: continue - if floating or self.pfsense.parse_interface(rule['interface']) == interface.text: + if ( + floating + or self.pfsense.parse_interface(rule["interface"]) == interface.text + ): defined_rules.add(descr.text) return True return False def want_rule_separator(self, separator_elt, rule_separators): - """ return True if we want to keep separator_elt """ - name = separator_elt.find('text').text - interface = separator_elt.find('if').text + """return True if we want to keep separator_elt""" + name = separator_elt.find("text").text + interface = separator_elt.find("if").text for separator in rule_separators: - if separator['state'] == 'absent': + if separator["state"] == "absent": continue - if separator['name'] != name: + if separator["name"] != name: continue - if separator.get('floating'): - if interface == 'floatingrules': + if separator.get("floating"): + if interface == "floatingrules": return True - elif self.pfsense.parse_interface(separator['interface']) == interface: + elif self.pfsense.parse_interface(separator["interface"]) == interface: return True return False def want_alias(self, alias_elt, aliases): - """ return True if we want to keep alias_elt """ - name = alias_elt.find('name') - alias_type = alias_elt.find('type') + """return True if we want to keep alias_elt""" + name = alias_elt.find("name") + alias_type = alias_elt.find("type") # probably not an alias if name is None or type is None: return True - if name.text in self.module.params['ignored_aliases']: + if name.text in self.module.params["ignored_aliases"]: return True for alias in aliases: - if alias['state'] == 'absent': + if alias["state"] == "absent": continue - if alias['name'] == name.text and alias['type'] == alias_type.text: + if alias["name"] == name.text and alias["type"] == alias_type.text: return True return False @staticmethod def want_interface(interface_elt, interfaces): - """ return True if we want to keep interface_elt """ - descr_elt = interface_elt.find('descr') + """return True if we want to keep interface_elt""" + descr_elt = interface_elt.find("descr") if descr_elt is not None and descr_elt.text: name = descr_elt.text else: name = interface_elt.tag for interface in interfaces: - if interface['state'] == 'absent': + if interface["state"] == "absent": continue - if interface['descr'] == name: + if interface["descr"] == name: return True return False @staticmethod def want_vlan(vlan_elt, vlans): - """ return True if we want to keep vlan_elt """ - tag = int(vlan_elt.find('tag').text) - interface = vlan_elt.find('if') + """return True if we want to keep vlan_elt""" + tag = int(vlan_elt.find("tag").text) + interface = vlan_elt.find("if") for vlan in vlans: - if vlan['state'] == 'absent': + if vlan["state"] == "absent": continue - if vlan['vlan_id'] == tag and vlan['interface'] == interface.text: + if vlan["vlan_id"] == tag and vlan["interface"] == interface.text: return True return False @@ -813,50 +848,58 @@ def is_filtered(interface_filter, params): if interface_filter is None: return False - if 'floating' in params: - if isinstance(params['floating'], str): - floating = params['floating'].lower() + if "floating" in params: + if isinstance(params["floating"], str): + floating = params["floating"].lower() else: - floating = 'true' if params['floating'] else 'false' + floating = "true" if params["floating"] else "false" - if floating != 'false' and floating != 'no': - return 'floating' not in interface_filter + if floating != "false" and floating != "no": + return "floating" not in interface_filter - return params['interface'].lower() not in interface_filter + return params["interface"].lower() not in interface_filter def run_rules(self): - """ process input params to add/update/delete all rules """ + """process input params to add/update/delete all rules""" - want = self.module.params['aggregated_rules'] - interface_filter = self.module.params['interface_filter'].lower().split(' ') if self.module.params.get('interface_filter') is not None else None + want = self.module.params["aggregated_rules"] + interface_filter = ( + self.module.params["interface_filter"].lower().split(" ") + if self.module.params.get("interface_filter") is not None + else None + ) if want is None: return # delete every other rule if required - if self.module.params['purge_rules']: + if self.module.params["purge_rules"]: todel = [] for rule_elt in self.pfsense_rules.root_elt: if not self.want_rule(rule_elt, want): params = {} - params['state'] = 'absent' - params['name'] = rule_elt.find('descr').text + params["state"] = "absent" + params["name"] = rule_elt.find("descr").text - if rule_elt.find('floating') is not None: - params['floating'] = True - interfaces = rule_elt.find('interface').text.split(',') - params['interface'] = list() + if rule_elt.find("floating") is not None: + params["floating"] = True + interfaces = rule_elt.find("interface").text.split(",") + params["interface"] = list() for interface in interfaces: - target = self.pfsense.get_interface_display_name(interface, return_none=True) + target = self.pfsense.get_interface_display_name( + interface, return_none=True + ) if target is not None: - params['interface'].append(target) + params["interface"].append(target) else: - params['interface'].append(interface) - params['interface'] = ','.join(params['interface']) + params["interface"].append(interface) + params["interface"] = ",".join(params["interface"]) else: - params['interface'] = self.pfsense.get_interface_display_name(rule_elt.find('interface').text, return_none=True) + params["interface"] = self.pfsense.get_interface_display_name( + rule_elt.find("interface").text, return_none=True + ) - if params['interface'] is None: + if params["interface"] is None: continue todel.append(params) @@ -867,28 +910,32 @@ def run_rules(self): self.pfsense_rules.run(params) # generating order if required - if self.module.params.get('order_rules'): + if self.module.params.get("order_rules"): last_rules = dict() for params in want: - if params.get('before') is not None or params.get('after') is not None: - self.module.fail_json(msg="You can't use after or before parameters on rules when using order_rules (see {0})".format(params['name'])) - - if params.get('state') == 'absent': + if params.get("before") is not None or params.get("after") is not None: + self.module.fail_json( + msg="You can't use after or before parameters on rules when using order_rules (see {0})".format( + params["name"] + ) + ) + + if params.get("state") == "absent": continue - if params.get('floating'): - key = 'floating' + if params.get("floating"): + key = "floating" else: - key = params['interface'] + key = params["interface"] # first rule on interface if key not in last_rules: - params['after'] = 'top' - last_rules[key] = params['name'] + params["after"] = "top" + last_rules[key] = params["name"] continue - params['after'] = last_rules[key] - last_rules[key] = params['name'] + params["after"] = last_rules[key] + last_rules[key] = params["name"] # processing aggregated parameters for params in want: @@ -897,25 +944,31 @@ def run_rules(self): self.pfsense_rules.run(params) def run_nat_outbounds_rules(self): - """ process input params to add/update/delete all nat_outbound rules """ + """process input params to add/update/delete all nat_outbound rules""" - want = self.module.params['aggregated_nat_outbounds'] - interface_filter = self.module.params['interface_filter'].lower().split(' ') if self.module.params.get('interface_filter') is not None else None + want = self.module.params["aggregated_nat_outbounds"] + interface_filter = ( + self.module.params["interface_filter"].lower().split(" ") + if self.module.params.get("interface_filter") is not None + else None + ) if want is None: return # delete every other rule if required - if self.module.params['purge_nat_outbounds']: + if self.module.params["purge_nat_outbounds"]: todel = [] for rule_elt in self.pfsense_nat_outbounds.root_elt: - if not self.want_rule(rule_elt, want, name_field='descr'): + if not self.want_rule(rule_elt, want, name_field="descr"): params = {} - params['state'] = 'absent' - params['descr'] = rule_elt.find('descr').text - params['interface'] = self.pfsense.get_interface_display_name(rule_elt.find('interface').text, return_none=True) + params["state"] = "absent" + params["descr"] = rule_elt.find("descr").text + params["interface"] = self.pfsense.get_interface_display_name( + rule_elt.find("interface").text, return_none=True + ) - if params['interface'] is None: + if params["interface"] is None: continue todel.append(params) @@ -932,25 +985,31 @@ def run_nat_outbounds_rules(self): self.pfsense_nat_outbounds.run(params) def run_nat_port_forwards_rules(self): - """ process input params to add/update/delete all nat_port_forwards_rule rules """ + """process input params to add/update/delete all nat_port_forwards_rule rules""" - want = self.module.params['aggregated_nat_port_forwards'] - interface_filter = self.module.params['interface_filter'].lower().split(' ') if self.module.params.get('interface_filter') is not None else None + want = self.module.params["aggregated_nat_port_forwards"] + interface_filter = ( + self.module.params["interface_filter"].lower().split(" ") + if self.module.params.get("interface_filter") is not None + else None + ) if want is None: return # delete every other rule if required - if self.module.params['purge_nat_port_forwards']: + if self.module.params["purge_nat_port_forwards"]: todel = [] for rule_elt in self.pfsense_nat_port_forwards.root_elt: - if not self.want_rule(rule_elt, want, name_field='descr'): + if not self.want_rule(rule_elt, want, name_field="descr"): params = {} - params['state'] = 'absent' - params['descr'] = rule_elt.find('descr').text - params['interface'] = self.pfsense.get_interface_display_name(rule_elt.find('interface').text, return_none=True) + params["state"] = "absent" + params["descr"] = rule_elt.find("descr").text + params["interface"] = self.pfsense.get_interface_display_name( + rule_elt.find("interface").text, return_none=True + ) - if params['interface'] is None: + if params["interface"] is None: continue todel.append(params) @@ -967,8 +1026,8 @@ def run_nat_port_forwards_rules(self): self.pfsense_nat_port_forwards.run(params) def run_aliases(self): - """ process input params to add/update/delete all aliases """ - want = self.module.params['aggregated_aliases'] + """process input params to add/update/delete all aliases""" + want = self.module.params["aggregated_aliases"] if want is None: return @@ -978,21 +1037,21 @@ def run_aliases(self): self.pfsense_aliases.run(param) # delete every other alias if required - if self.module.params['purge_aliases']: + if self.module.params["purge_aliases"]: todel = [] for alias_elt in self.pfsense_aliases.root_elt: if not self.want_alias(alias_elt, want): params = {} - params['state'] = 'absent' - params['name'] = alias_elt.find('name').text + params["state"] = "absent" + params["name"] = alias_elt.find("name").text todel.append(params) for params in todel: self.pfsense_aliases.run(params) def run_interfaces(self): - """ process input params to add/update/delete all interfaces """ - want = self.module.params['aggregated_interfaces'] + """process input params to add/update/delete all interfaces""" + want = self.module.params["aggregated_interfaces"] if want is None: return @@ -1002,24 +1061,28 @@ def run_interfaces(self): self.pfsense_interfaces.run(param) # delete every other if required - if self.module.params['purge_interfaces']: + if self.module.params["purge_interfaces"]: todel = [] for interface_elt in self.pfsense_interfaces.root_elt: if not self.want_interface(interface_elt, want): params = {} - params['state'] = 'absent' - descr_elt = interface_elt.find('descr') + params["state"] = "absent" + descr_elt = interface_elt.find("descr") if descr_elt is not None and descr_elt.text: - params['descr'] = descr_elt.text + params["descr"] = descr_elt.text todel.append(params) for params in todel: self.pfsense_interfaces.run(params) def run_rule_separators(self): - """ process input params to add/update/delete all separators """ - want = self.module.params['aggregated_rule_separators'] - interface_filter = self.module.params['interface_filter'].lower().split(' ') if self.module.params.get('interface_filter') is not None else None + """process input params to add/update/delete all separators""" + want = self.module.params["aggregated_rule_separators"] + interface_filter = ( + self.module.params["interface_filter"].lower().split(" ") + if self.module.params.get("interface_filter") is not None + else None + ) if want is None: return @@ -1031,19 +1094,23 @@ def run_rule_separators(self): self.pfsense_rule_separators.run(params) # delete every other if required - if self.module.params['purge_rule_separators']: + if self.module.params["purge_rule_separators"]: todel = [] for interface_elt in self.pfsense_rule_separators.separators: for separator_elt in interface_elt: if not self.want_rule_separator(separator_elt, want): params = {} - params['state'] = 'absent' - params['name'] = separator_elt.find('text').text - if interface_elt.tag == 'floatingrules': - params['floating'] = True + params["state"] = "absent" + params["name"] = separator_elt.find("text").text + if interface_elt.tag == "floatingrules": + params["floating"] = True else: - params['interface'] = self.pfsense.get_interface_display_name(interface_elt.tag, return_none=True) - if params['interface'] is None: + params["interface"] = ( + self.pfsense.get_interface_display_name( + interface_elt.tag, return_none=True + ) + ) + if params["interface"] is None: continue todel.append(params) @@ -1053,8 +1120,8 @@ def run_rule_separators(self): self.pfsense_rule_separators.run(params) def run_vlans(self): - """ process input params to add/update/delete all VLANs """ - want = self.module.params['aggregated_vlans'] + """process input params to add/update/delete all VLANs""" + want = self.module.params["aggregated_vlans"] if want is None: return @@ -1064,88 +1131,127 @@ def run_vlans(self): self.pfsense_vlans.run(param) # delete every other if required - if self.module.params['purge_vlans']: + if self.module.params["purge_vlans"]: todel = [] for vlan_elt in self.pfsense_vlans.root_elt: if not self.want_vlan(vlan_elt, want): params = {} - params['state'] = 'absent' - params['interface'] = vlan_elt.find('if').text - params['vlan_id'] = int(vlan_elt.find('tag').text) + params["state"] = "absent" + params["interface"] = vlan_elt.find("if").text + params["vlan_id"] = int(vlan_elt.find("tag").text) todel.append(params) for params in todel: self.pfsense_vlans.run(params) def commit_changes(self): - """ apply changes and exit module """ - stdout = '' - stderr = '' + """apply changes and exit module""" + stdout = "" + stderr = "" changed = ( - self.pfsense_aliases.result['changed'] or self.pfsense_interfaces.result['changed'] or self.pfsense_nat_outbounds.result['changed'] - or self.pfsense_nat_port_forwards.result['changed'] or self.pfsense_rules.result['changed'] - or self.pfsense_rule_separators.result['changed'] or self.pfsense_vlans.result['changed'] + self.pfsense_aliases.result["changed"] + or self.pfsense_interfaces.result["changed"] + or self.pfsense_nat_outbounds.result["changed"] + or self.pfsense_nat_port_forwards.result["changed"] + or self.pfsense_rules.result["changed"] + or self.pfsense_rule_separators.result["changed"] + or self.pfsense_vlans.result["changed"] ) if changed and not self.module.check_mode: - self.pfsense.write_config(descr='aggregated change') + self.pfsense.write_config(descr="aggregated change") (dummy, stdout, stderr) = self._update() result = {} - result['result_aliases'] = self.pfsense_aliases.result['commands'] - result['result_interfaces'] = self.pfsense_interfaces.result['commands'] - result['result_nat_outbounds'] = self.pfsense_nat_outbounds.result['commands'] - result['result_nat_port_forwards'] = self.pfsense_nat_port_forwards.result['commands'] - result['result_rules'] = self.pfsense_rules.result['commands'] - result['result_rule_separators'] = self.pfsense_rule_separators.result['commands'] - result['result_vlans'] = self.pfsense_vlans.result['commands'] - result['changed'] = changed - result['stdout'] = stdout - result['stderr'] = stderr + result["result_aliases"] = self.pfsense_aliases.result["commands"] + result["result_interfaces"] = self.pfsense_interfaces.result["commands"] + result["result_nat_outbounds"] = self.pfsense_nat_outbounds.result["commands"] + result["result_nat_port_forwards"] = self.pfsense_nat_port_forwards.result[ + "commands" + ] + result["result_rules"] = self.pfsense_rules.result["commands"] + result["result_rule_separators"] = self.pfsense_rule_separators.result[ + "commands" + ] + result["result_vlans"] = self.pfsense_vlans.result["commands"] + result["changed"] = changed + result["stdout"] = stdout + result["stderr"] = stderr self.module.exit_json(**result) def main(): argument_spec = dict( aggregated_aliases=dict( - type='list', elements='dict', options=ALIAS_ARGUMENT_SPEC, mutually_exclusive=ALIAS_MUTUALLY_EXCLUSIVE, required_if=ALIAS_REQUIRED_IF), + type="list", + elements="dict", + options=ALIAS_ARGUMENT_SPEC, + mutually_exclusive=ALIAS_MUTUALLY_EXCLUSIVE, + required_if=ALIAS_REQUIRED_IF, + ), aggregated_interfaces=dict( - type='list', elements='dict', - options=INTERFACE_ARGUMENT_SPEC, required_if=INTERFACE_REQUIRED_IF, mutually_exclusive=INTERFACE_MUTUALLY_EXCLUSIVE), - aggregated_rules=dict(type='list', elements='dict', options=RULE_ARGUMENT_SPEC, required_if=RULE_REQUIRED_IF), - aggregated_nat_outbounds=dict(type='list', elements='dict', options=NAT_OUTBOUND_ARGUMENT_SPEC, required_if=NAT_OUTBOUND_REQUIRED_IF), - aggregated_nat_port_forwards=dict(type='list', elements='dict', options=NAT_PORT_FORWARD_ARGUMENT_SPEC, required_if=NAT_PORT_FORWARD_REQUIRED_IF), + type="list", + elements="dict", + options=INTERFACE_ARGUMENT_SPEC, + required_if=INTERFACE_REQUIRED_IF, + mutually_exclusive=INTERFACE_MUTUALLY_EXCLUSIVE, + ), + aggregated_rules=dict( + type="list", + elements="dict", + options=RULE_ARGUMENT_SPEC, + required_if=RULE_REQUIRED_IF, + ), + aggregated_nat_outbounds=dict( + type="list", + elements="dict", + options=NAT_OUTBOUND_ARGUMENT_SPEC, + required_if=NAT_OUTBOUND_REQUIRED_IF, + ), + aggregated_nat_port_forwards=dict( + type="list", + elements="dict", + options=NAT_PORT_FORWARD_ARGUMENT_SPEC, + required_if=NAT_PORT_FORWARD_REQUIRED_IF, + ), aggregated_rule_separators=dict( - type='list', elements='dict', - options=RULE_SEPARATOR_ARGUMENT_SPEC, required_one_of=RULE_SEPARATOR_REQUIRED_ONE_OF, mutually_exclusive=RULE_SEPARATOR_MUTUALLY_EXCLUSIVE), - aggregated_vlans=dict(type='list', elements='dict', options=VLAN_ARGUMENT_SPEC), - order_rules=dict(default=False, type='bool'), - purge_aliases=dict(default=False, type='bool'), - purge_interfaces=dict(default=False, type='bool'), - purge_nat_outbounds=dict(default=False, type='bool'), - purge_nat_port_forwards=dict(default=False, type='bool'), - purge_rules=dict(default=False, type='bool'), - purge_rule_separators=dict(default=False, type='bool'), - purge_vlans=dict(default=False, type='bool'), - interface_filter=dict(required=False, type='str'), - ignored_aliases=dict(type='list', elements='str', default=[]), - ignored_rules=dict(type='list', elements='str', default=[]), + type="list", + elements="dict", + options=RULE_SEPARATOR_ARGUMENT_SPEC, + required_one_of=RULE_SEPARATOR_REQUIRED_ONE_OF, + mutually_exclusive=RULE_SEPARATOR_MUTUALLY_EXCLUSIVE, + ), + aggregated_vlans=dict(type="list", elements="dict", options=VLAN_ARGUMENT_SPEC), + order_rules=dict(default=False, type="bool"), + purge_aliases=dict(default=False, type="bool"), + purge_interfaces=dict(default=False, type="bool"), + purge_nat_outbounds=dict(default=False, type="bool"), + purge_nat_port_forwards=dict(default=False, type="bool"), + purge_rules=dict(default=False, type="bool"), + purge_rule_separators=dict(default=False, type="bool"), + purge_vlans=dict(default=False, type="bool"), + interface_filter=dict(required=False, type="str"), + ignored_aliases=dict(type="list", elements="str", default=[]), + ignored_rules=dict(type="list", elements="str", default=[]), ) - required_one_of = [[ - 'aggregated_aliases', - 'aggregated_interfaces', - 'aggregated_nat_outbounds', - 'aggregated_nat_port_forwards', - 'aggregated_rules', - 'aggregated_rule_separators', - 'aggregated_vlans' - ]] + required_one_of = [ + [ + "aggregated_aliases", + "aggregated_interfaces", + "aggregated_nat_outbounds", + "aggregated_nat_port_forwards", + "aggregated_rules", + "aggregated_rule_separators", + "aggregated_vlans", + ] + ] module = AnsibleModule( argument_spec=argument_spec, required_one_of=required_one_of, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseModuleAggregate(module) @@ -1161,5 +1267,5 @@ def main(): pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_alias.py b/plugins/modules/pfsense_alias.py index 5c6306a5..3dba3849 100644 --- a/plugins/modules/pfsense_alias.py +++ b/plugins/modules/pfsense_alias.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -84,7 +87,12 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.alias import PFSenseAliasModule, ALIAS_ARGUMENT_SPEC, ALIAS_MUTUALLY_EXCLUSIVE, ALIAS_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.alias import ( + PFSenseAliasModule, + ALIAS_ARGUMENT_SPEC, + ALIAS_MUTUALLY_EXCLUSIVE, + ALIAS_REQUIRED_IF, +) def main(): @@ -92,12 +100,13 @@ def main(): argument_spec=ALIAS_ARGUMENT_SPEC, mutually_exclusive=ALIAS_MUTUALLY_EXCLUSIVE, required_if=ALIAS_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseAliasModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_authserver_ldap.py b/plugins/modules/pfsense_authserver_ldap.py index cd6ffed3..1299c0e0 100644 --- a/plugins/modules/pfsense_authserver_ldap.py +++ b/plugins/modules/pfsense_authserver_ldap.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -145,49 +148,39 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) PFSENSE_AUTHSERVER_LDAP_SPEC = { - 'name': {'required': True, 'type': 'str'}, - 'state': { - 'default': 'present', - 'choices': ['present', 'absent'] - }, - 'host': {'type': 'str'}, - 'port': {'default': '389', 'type': 'str'}, - 'transport': { - 'choices': ['tcp', 'starttls', 'ssl'] - }, - 'ca': {'default': 'global', 'type': 'str'}, - 'protver': { - 'default': '3', - 'choices': ['2', '3'] - }, - 'timeout': {'default': '25', 'type': 'str'}, - 'scope': { - 'choices': ['one', 'subtree'] - }, - 'basedn': {'required': False, 'type': 'str'}, - 'authcn': {'required': False, 'type': 'str'}, - 'extended_enabled': {'default': False, 'type': 'bool'}, - 'extended_query': {'default': '', 'type': 'str'}, - 'binddn': {'required': False, 'type': 'str'}, - 'bindpw': {'required': False, 'type': 'str'}, - 'attr_user': {'default': 'cn', 'type': 'str'}, - 'attr_group': {'default': 'cn', 'type': 'str'}, - 'attr_member': {'default': 'member', 'type': 'str'}, - 'attr_groupobj': {'default': 'posixGroup', 'type': 'str'}, - 'ldap_pam_groupdn': {'required': False, 'type': 'str'}, - 'ldap_utf8': {'required': False, 'type': 'bool'}, - 'ldap_nostrip_at': {'required': False, 'type': 'bool'}, - 'ldap_rfc2307': {'required': False, 'type': 'bool'}, - 'ldap_rfc2307_userdn': {'required': False, 'type': 'bool'}, - 'ldap_allow_unauthenticated': {'required': False, 'type': 'bool'}, + "name": {"required": True, "type": "str"}, + "state": {"default": "present", "choices": ["present", "absent"]}, + "host": {"type": "str"}, + "port": {"default": "389", "type": "str"}, + "transport": {"choices": ["tcp", "starttls", "ssl"]}, + "ca": {"default": "global", "type": "str"}, + "protver": {"default": "3", "choices": ["2", "3"]}, + "timeout": {"default": "25", "type": "str"}, + "scope": {"choices": ["one", "subtree"]}, + "basedn": {"required": False, "type": "str"}, + "authcn": {"required": False, "type": "str"}, + "extended_enabled": {"default": False, "type": "bool"}, + "extended_query": {"default": "", "type": "str"}, + "binddn": {"required": False, "type": "str"}, + "bindpw": {"required": False, "type": "str"}, + "attr_user": {"default": "cn", "type": "str"}, + "attr_group": {"default": "cn", "type": "str"}, + "attr_member": {"default": "member", "type": "str"}, + "attr_groupobj": {"default": "posixGroup", "type": "str"}, + "ldap_pam_groupdn": {"required": False, "type": "str"}, + "ldap_utf8": {"required": False, "type": "bool"}, + "ldap_nostrip_at": {"required": False, "type": "bool"}, + "ldap_rfc2307": {"required": False, "type": "bool"}, + "ldap_rfc2307_userdn": {"required": False, "type": "bool"}, + "ldap_allow_unauthenticated": {"required": False, "type": "bool"}, } -AUTHSERVER_LDAP_CREATE_DEFAULT = dict( - ldap_allow_unauthenticated=None -) +AUTHSERVER_LDAP_CREATE_DEFAULT = dict(ldap_allow_unauthenticated=None) AUTHSERVER_LDAP_PHP_COMMAND = """ require_once('auth.inc'); @@ -199,7 +192,7 @@ class PFSenseAuthserverLDAPModule(PFSenseModuleBase): - """ module managing pfsense LDAP authentication """ + """module managing pfsense LDAP authentication""" ############################## # unit tests @@ -207,82 +200,108 @@ class PFSenseAuthserverLDAPModule(PFSenseModuleBase): # Must be class method for unit test usage @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return PFSENSE_AUTHSERVER_LDAP_SPEC def __init__(self, module, pfsense=None): - super(PFSenseAuthserverLDAPModule, self).__init__(module, pfsense, name='pfsense_authserver_ldap', root='system', node='authserver', key='name', - bool_style='absent/present', have_refid=True, create_default=AUTHSERVER_LDAP_CREATE_DEFAULT) + super(PFSenseAuthserverLDAPModule, self).__init__( + module, + pfsense, + name="pfsense_authserver_ldap", + root="system", + node="authserver", + key="name", + bool_style="absent/present", + have_refid=True, + create_default=AUTHSERVER_LDAP_CREATE_DEFAULT, + ) ############################## # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" - if int(self.params['timeout']) < 1: - self.module.fail_json(msg='timeout {0} must be greater than 1'.format(self.params['timeout'])) + if int(self.params["timeout"]) < 1: + self.module.fail_json( + msg="timeout {0} must be greater than 1".format(self.params["timeout"]) + ) def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() - obj['name'] = params['name'] - if params['state'] == 'present': - obj['type'] = 'ldap' - for option in ['host']: + obj["name"] = params["name"] + if params["state"] == "present": + obj["type"] = "ldap" + for option in ["host"]: if option in params and params[option] is not None: obj[option] = params[option] - obj['ldap_port'] = params['port'] + obj["ldap_port"] = params["port"] if self.pfsense.config_version >= 20.1: - urltype = dict({'tcp': 'Standard TCP', 'starttls': 'STARTTLS Encrypted', 'ssl': 'SSL/TLS Encrypted'}) + urltype = dict( + { + "tcp": "Standard TCP", + "starttls": "STARTTLS Encrypted", + "ssl": "SSL/TLS Encrypted", + } + ) else: - urltype = dict({'tcp': 'TCP - Standard', 'starttls': 'TCP - STARTTLS', 'ssl': 'SSL - Encrypted'}) - obj['ldap_urltype'] = urltype[params['transport']] - obj['ldap_protver'] = params['protver'] - obj['ldap_timeout'] = params['timeout'] - obj['ldap_scope'] = params['scope'] - obj['ldap_basedn'] = params['basedn'] - obj['ldap_authcn'] = params['authcn'] - if params['extended_enabled']: - obj['ldap_extended_enabled'] = 'yes' + urltype = dict( + { + "tcp": "TCP - Standard", + "starttls": "TCP - STARTTLS", + "ssl": "SSL - Encrypted", + } + ) + obj["ldap_urltype"] = urltype[params["transport"]] + obj["ldap_protver"] = params["protver"] + obj["ldap_timeout"] = params["timeout"] + obj["ldap_scope"] = params["scope"] + obj["ldap_basedn"] = params["basedn"] + obj["ldap_authcn"] = params["authcn"] + if params["extended_enabled"]: + obj["ldap_extended_enabled"] = "yes" else: - obj['ldap_extended_enabled'] = '' - obj['ldap_extended_query'] = params['extended_query'] - if params['binddn']: - obj['ldap_binddn'] = params['binddn'] - if params['bindpw']: - obj['ldap_bindpw'] = params['bindpw'] - obj['ldap_attr_user'] = params['attr_user'] - obj['ldap_attr_group'] = params['attr_group'] - obj['ldap_attr_member'] = params['attr_member'] - obj['ldap_attr_groupobj'] = params['attr_groupobj'] - if params['ldap_utf8']: - obj['ldap_utf8'] = '' - if params['ldap_nostrip_at']: - obj['ldap_nostrip_at'] = '' - if params['ldap_rfc2307']: - obj['ldap_rfc2307'] = '' + obj["ldap_extended_enabled"] = "" + obj["ldap_extended_query"] = params["extended_query"] + if params["binddn"]: + obj["ldap_binddn"] = params["binddn"] + if params["bindpw"]: + obj["ldap_bindpw"] = params["bindpw"] + obj["ldap_attr_user"] = params["attr_user"] + obj["ldap_attr_group"] = params["attr_group"] + obj["ldap_attr_member"] = params["attr_member"] + obj["ldap_attr_groupobj"] = params["attr_groupobj"] + if params["ldap_utf8"]: + obj["ldap_utf8"] = "" + if params["ldap_nostrip_at"]: + obj["ldap_nostrip_at"] = "" + if params["ldap_rfc2307"]: + obj["ldap_rfc2307"] = "" if self.pfsense.is_at_least_2_5_0(): - obj['ldap_pam_groupdn'] = params['ldap_pam_groupdn'] - if params['ldap_rfc2307_userdn']: - obj['ldap_rfc2307_userdn'] = '' - if params['ldap_allow_unauthenticated']: - obj['ldap_allow_unauthenticated'] = '' + obj["ldap_pam_groupdn"] = params["ldap_pam_groupdn"] + if params["ldap_rfc2307_userdn"]: + obj["ldap_rfc2307_userdn"] = "" + if params["ldap_allow_unauthenticated"]: + obj["ldap_allow_unauthenticated"] = "" # Find the caref id for the named CA - obj['ldap_caref'] = self.pfsense.get_caref(params['ca']) + obj["ldap_caref"] = self.pfsense.get_caref(params["ca"]) # CA is required for SSL/TLS if self.pfsense.config_version >= 20.1: - if obj['ldap_caref'] is None and obj['ldap_urltype'] != 'Standard TCP': - self.module.fail_json(msg="Could not find CA '%s'" % (params['ca'])) + if obj["ldap_caref"] is None and obj["ldap_urltype"] != "Standard TCP": + self.module.fail_json(msg="Could not find CA '%s'" % (params["ca"])) else: - if obj['ldap_caref'] is None and obj['ldap_urltype'] != 'TCP - Standard': - self.module.fail_json(msg="Could not find CA '%s'" % (params['ca'])) + if ( + obj["ldap_caref"] is None + and obj["ldap_urltype"] != "TCP - Standard" + ): + self.module.fail_json(msg="Could not find CA '%s'" % (params["ca"])) return obj @@ -290,11 +309,17 @@ def _params_to_obj(self): # XML processing # def _find_target(self): - result = self.root_elt.findall("authserver[name='{0}'][type='ldap']".format(self.obj['name'])) + result = self.root_elt.findall( + "authserver[name='{0}'][type='ldap']".format(self.obj["name"]) + ) if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple ldap authentication servers for name {0}.'.format(self.obj['name'])) + self.module.fail_json( + msg="Found multiple ldap authentication servers for name {0}.".format( + self.obj["name"] + ) + ) else: return None @@ -302,8 +327,10 @@ def _find_target(self): # run # def _update(self): - """ update system configuration if needed """ - return self.pfsense.phpshell(AUTHSERVER_LDAP_PHP_COMMAND.format(name=self.obj['name'])) + """update system configuration if needed""" + return self.pfsense.phpshell( + AUTHSERVER_LDAP_PHP_COMMAND.format(name=self.obj["name"]) + ) def main(): @@ -312,12 +339,13 @@ def main(): required_if=[ ["state", "present", ["host", "port", "transport", "scope", "authcn"]], ], - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseAuthserverLDAPModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_authserver_radius.py b/plugins/modules/pfsense_authserver_radius.py index 6a3957ea..4878d7a7 100644 --- a/plugins/modules/pfsense_authserver_radius.py +++ b/plugins/modules/pfsense_authserver_radius.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -81,68 +84,74 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) AUTHSERVER_RADIUS_SPEC = { - 'name': {'required': True, 'type': 'str'}, - 'state': { - 'default': 'present', - 'choices': ['present', 'absent'], + "name": {"required": True, "type": "str"}, + "state": { + "default": "present", + "choices": ["present", "absent"], }, - 'host': {'type': 'str'}, - 'auth_port': {'default': '1812', 'type': 'int'}, - 'acct_port': {'default': '1813', 'type': 'int'}, - 'protocol': { - 'default': 'MSCHAPv2', - 'choices': ['PAP', 'CHAP_MD5', 'MSCHAPv1', 'MSCHAPv2'], + "host": {"type": "str"}, + "auth_port": {"default": "1812", "type": "int"}, + "acct_port": {"default": "1813", "type": "int"}, + "protocol": { + "default": "MSCHAPv2", + "choices": ["PAP", "CHAP_MD5", "MSCHAPv1", "MSCHAPv2"], }, - 'secret': {'type': 'str', 'no_log': True}, - 'timeout': {'default': '5', 'type': 'int'}, - 'nasip_attribute': {'default': 'lan', 'type': 'str'}, + "secret": {"type": "str", "no_log": True}, + "timeout": {"default": "5", "type": "int"}, + "nasip_attribute": {"default": "lan", "type": "str"}, } class PFSenseAuthserverRADIUSModule(PFSenseModuleBase): - """ module managing pfsense RADIUS authentication """ + """module managing pfsense RADIUS authentication""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return AUTHSERVER_RADIUS_SPEC def __init__(self, module, pfsense=None): super(PFSenseAuthserverRADIUSModule, self).__init__(module, pfsense) self.name = "pfsense_authserver_radius" - self.root_elt = self.pfsense.get_element('system') - self.authservers = self.root_elt.findall('authserver') + self.root_elt = self.pfsense.get_element("system") + self.authservers = self.root_elt.findall("authserver") ############################## # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" - if int(self.params['timeout']) < 1: - self.module.fail_json(msg='timeout {0} must be greater than 1'.format(self.params['timeout'])) + if int(self.params["timeout"]) < 1: + self.module.fail_json( + msg="timeout {0} must be greater than 1".format(self.params["timeout"]) + ) def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() self.obj = obj - obj['name'] = params['name'] - if params['state'] == 'present': - obj['type'] = 'radius' - self._get_ansible_param(obj, 'host') - self._get_ansible_param(obj, 'auth_port', fname='radius_auth_port') - self._get_ansible_param(obj, 'acct_port', fname='radius_acct_port') - self._get_ansible_param(obj, 'protocol', fname='radius_protocol') - self._get_ansible_param(obj, 'secret', fname='radius_secret') - self._get_ansible_param(obj, 'timeout', fname='radius_timeout') - self._get_ansible_param(obj, 'nasip_attribute', fname='radius_nasip_attribute') + obj["name"] = params["name"] + if params["state"] == "present": + obj["type"] = "radius" + self._get_ansible_param(obj, "host") + self._get_ansible_param(obj, "auth_port", fname="radius_auth_port") + self._get_ansible_param(obj, "acct_port", fname="radius_acct_port") + self._get_ansible_param(obj, "protocol", fname="radius_protocol") + self._get_ansible_param(obj, "secret", fname="radius_secret") + self._get_ansible_param(obj, "timeout", fname="radius_timeout") + self._get_ansible_param( + obj, "nasip_attribute", fname="radius_nasip_attribute" + ) return obj @@ -150,11 +159,17 @@ def _params_to_obj(self): # XML processing # def _find_target(self): - result = self.root_elt.findall("authserver[name='{0}'][type='radius']".format(self.obj['name'])) + result = self.root_elt.findall( + "authserver[name='{0}'][type='radius']".format(self.obj["name"]) + ) if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple radius authentication servers for name {0}.'.format(self.obj['name'])) + self.module.fail_json( + msg="Found multiple radius authentication servers for name {0}.".format( + self.obj["name"] + ) + ) else: return None @@ -162,17 +177,20 @@ def _find_this_index(self): return self.authservers.index(self.target_elt) def _create_target(self): - """ create the XML target_elt """ - elt = self.pfsense.new_element('authserver') - elt.append(self.pfsense.new_element('refid', text=self.pfsense.uniqid())) + """create the XML target_elt""" + elt = self.pfsense.new_element("authserver") + elt.append(self.pfsense.new_element("refid", text=self.pfsense.uniqid())) return elt def _copy_and_add_target(self): - """ populate the XML target_elt """ + """populate the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.obj + self.diff["after"] = self.obj if len(self.authservers) > 0: - self.root_elt.insert(list(self.root_elt).index(self.authservers[len(self.authservers) - 1]), self.target_elt) + self.root_elt.insert( + list(self.root_elt).index(self.authservers[len(self.authservers) - 1]), + self.target_elt, + ) else: self.root_elt.append(self.target_elt) @@ -180,12 +198,12 @@ def _copy_and_add_target(self): # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'{0}'".format(self.obj['name']) + """return obj's name""" + return "'{0}'".format(self.obj["name"]) def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" return values @@ -195,12 +213,13 @@ def main(): required_if=[ ["state", "present", ["host", "secret"]], ], - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseAuthserverRADIUSModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_ca.py b/plugins/modules/pfsense_ca.py index 0c1fb471..e9730cdc 100644 --- a/plugins/modules/pfsense_ca.py +++ b/plugins/modules/pfsense_ca.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -167,162 +170,300 @@ import re from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) PFSENSE_CA_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - method=dict(type='str', default='existing', choices=['internal', 'existing', 'intermediate']), - state=dict(type='str', default='present', choices=['present', 'absent']), - trust=dict(type='bool'), - randomserial=dict(type='bool'), - certificate=dict(type='str'), - crl=dict(default=None, type='str'), - crlname=dict(default=None, type='str'), - crlrefid=dict(default=None, type='str'), - key=dict(type='str', no_log=True), - keytype=dict(type='str', default='RSA', choices=['RSA', 'ECDSA']), + name=dict(required=True, type="str"), + method=dict( + type="str", default="existing", choices=["internal", "existing", "intermediate"] + ), + state=dict(type="str", default="present", choices=["present", "absent"]), + trust=dict(type="bool"), + randomserial=dict(type="bool"), + certificate=dict(type="str"), + crl=dict(default=None, type="str"), + crlname=dict(default=None, type="str"), + crlrefid=dict(default=None, type="str"), + key=dict(type="str", no_log=True), + keytype=dict(type="str", default="RSA", choices=["RSA", "ECDSA"]), ecname=dict( - type='str', - default='prime256v1', + type="str", + default="prime256v1", choices=[ - 'secp112r1', 'secp112r2', 'secp128r1', 'secp128r2', 'secp160k1', 'secp160r1', 'secp160r2', - 'secp192k1', 'secp224k1', 'secp224r1', 'secp256k1', 'secp384r1', 'secp521r1', 'prime192v1', 'prime192v2', 'prime192v3', 'prime239v1', - 'prime239v2', 'prime239v3', 'prime256v1', 'sect113r1', 'sect113r2', 'sect131r1', 'sect131r2', 'sect163k1', 'sect163r1', 'sect163r2', - 'sect193r1', 'sect193r2', 'sect233k1', 'sect233r1', 'sect239k1', 'sect283k1', 'sect283r1', 'sect409k1', 'sect409r1', 'sect571k1', 'sect571r1', - 'c2pnb163v1', 'c2pnb163v2', 'c2pnb163v3', 'c2pnb176v1', 'c2tnb191v1', 'c2tnb191v2', 'c2tnb191v3', 'c2pnb208w1', 'c2tnb239v1', 'c2tnb239v2', - 'c2tnb239v3', 'c2pnb272w1', 'c2pnb304w1', 'c2tnb359v1', 'c2pnb368w1', 'c2tnb431r1', 'wap-wsg-idm-ecid-wtls1', 'wap-wsg-idm-ecid-wtls3', - 'wap-wsg-idm-ecid-wtls4', 'wap-wsg-idm-ecid-wtls5', 'wap-wsg-idm-ecid-wtls6', 'wap-wsg-idm-ecid-wtls7', 'wap-wsg-idm-ecid-wtls8', - 'wap-wsg-idm-ecid-wtls9', 'wap-wsg-idm-ecid-wtls10', 'wap-wsg-idm-ecid-wtls11', 'wap-wsg-idm-ecid-wtls12', 'Oakley-EC2N-3', 'Oakley-EC2N-4', - 'brainpoolP160r1', 'brainpoolP160t1', 'brainpoolP192r1', 'brainpoolP192t1', 'brainpoolP224r1', 'brainpoolP224t1', 'brainpoolP256r1', - 'brainpoolP256t1', 'brainpoolP320r1', 'brainpoolP320t1', 'brainpoolP384r1', 'brainpoolP384t1', 'brainpoolP512r1', 'brainpoolP512t1', 'SM2']), - keylen=dict(type='str', default='2048', choices=["1024", "2048", "3072", "4096", "6144", "7680", "8192", "15360", "16384"]), - digest_alg=dict(type='str', default='sha256', choices=['sha1', 'sha224', 'sha256', 'sha384', 'sha512']), - lifetime=dict(default=3650, type='int'), - dn_commonname=dict(default='internal-ca', type='str'), - dn_country=dict(default='', type='str'), - dn_state=dict(default='', type='str'), - dn_city=dict(default='', type='str'), - dn_organization=dict(default='', type='str'), - dn_organizationalunit=dict(default='', type='str'), - serial=dict(type='int'), + "secp112r1", + "secp112r2", + "secp128r1", + "secp128r2", + "secp160k1", + "secp160r1", + "secp160r2", + "secp192k1", + "secp224k1", + "secp224r1", + "secp256k1", + "secp384r1", + "secp521r1", + "prime192v1", + "prime192v2", + "prime192v3", + "prime239v1", + "prime239v2", + "prime239v3", + "prime256v1", + "sect113r1", + "sect113r2", + "sect131r1", + "sect131r2", + "sect163k1", + "sect163r1", + "sect163r2", + "sect193r1", + "sect193r2", + "sect233k1", + "sect233r1", + "sect239k1", + "sect283k1", + "sect283r1", + "sect409k1", + "sect409r1", + "sect571k1", + "sect571r1", + "c2pnb163v1", + "c2pnb163v2", + "c2pnb163v3", + "c2pnb176v1", + "c2tnb191v1", + "c2tnb191v2", + "c2tnb191v3", + "c2pnb208w1", + "c2tnb239v1", + "c2tnb239v2", + "c2tnb239v3", + "c2pnb272w1", + "c2pnb304w1", + "c2tnb359v1", + "c2pnb368w1", + "c2tnb431r1", + "wap-wsg-idm-ecid-wtls1", + "wap-wsg-idm-ecid-wtls3", + "wap-wsg-idm-ecid-wtls4", + "wap-wsg-idm-ecid-wtls5", + "wap-wsg-idm-ecid-wtls6", + "wap-wsg-idm-ecid-wtls7", + "wap-wsg-idm-ecid-wtls8", + "wap-wsg-idm-ecid-wtls9", + "wap-wsg-idm-ecid-wtls10", + "wap-wsg-idm-ecid-wtls11", + "wap-wsg-idm-ecid-wtls12", + "Oakley-EC2N-3", + "Oakley-EC2N-4", + "brainpoolP160r1", + "brainpoolP160t1", + "brainpoolP192r1", + "brainpoolP192t1", + "brainpoolP224r1", + "brainpoolP224t1", + "brainpoolP256r1", + "brainpoolP256t1", + "brainpoolP320r1", + "brainpoolP320t1", + "brainpoolP384r1", + "brainpoolP384t1", + "brainpoolP512r1", + "brainpoolP512t1", + "SM2", + ], + ), + keylen=dict( + type="str", + default="2048", + choices=[ + "1024", + "2048", + "3072", + "4096", + "6144", + "7680", + "8192", + "15360", + "16384", + ], + ), + digest_alg=dict( + type="str", + default="sha256", + choices=["sha1", "sha224", "sha256", "sha384", "sha512"], + ), + lifetime=dict(default=3650, type="int"), + dn_commonname=dict(default="internal-ca", type="str"), + dn_country=dict(default="", type="str"), + dn_state=dict(default="", type="str"), + dn_city=dict(default="", type="str"), + dn_organization=dict(default="", type="str"), + dn_organizationalunit=dict(default="", type="str"), + serial=dict(type="int"), ) # These are default but not enforced values CA_CREATE_DEFAULT = dict( - randomserial='disabled', - serial='0', - trust='disabled', + randomserial="disabled", + serial="0", + trust="disabled", ) # Booleans that map to different values CA_BOOL_VALUES = dict( - randomserial=('disabled', 'enabled'), - trust=('disabled', 'enabled'), + randomserial=("disabled", "enabled"), + trust=("disabled", "enabled"), ) class PFSenseCAModule(PFSenseModuleBase): - """ module managing pfsense certificate authorities """ + """module managing pfsense certificate authorities""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return PFSENSE_CA_ARGUMENT_SPEC def __init__(self, module, pfsense=None): - super(PFSenseCAModule, self).__init__(module, pfsense, root='pfsense', node='ca', have_refid=True, create_default=CA_CREATE_DEFAULT, - bool_values=CA_BOOL_VALUES) + super(PFSenseCAModule, self).__init__( + module, + pfsense, + root="pfsense", + node="ca", + have_refid=True, + create_default=CA_CREATE_DEFAULT, + bool_values=CA_BOOL_VALUES, + ) self.name = "pfsense_ca" self.refresh_crls = False self.crl = None - cmd = ('require_once("certs.inc");' - '$max_lifetime = cert_get_max_lifetime();' - 'echo json_encode($max_lifetime);') + cmd = ( + 'require_once("certs.inc");' + "$max_lifetime = cert_get_max_lifetime();" + "echo json_encode($max_lifetime);" + ) self.max_lifetime = int(self.pfsense.php(cmd)) ############################## # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params['state'] == 'absent': + if params["state"] == "absent": return - if re.search(r"[\?\>\<\&\/\\\"\']", params['name']): - self.module.fail_json(msg='name contains invalid characters') + if re.search(r"[\?\>\<\&\/\\\"\']", params["name"]): + self.module.fail_json(msg="name contains invalid characters") pattern = re.compile(r"[^a-zA-Z0-9 '/~`!@#$%\^&*()_\-+={}[\]|;:\"<>,.?\\]") - for param in ['dn_commonname', 'dn_state', 'dn_city', 'dn_organization', 'dn_organizationalunit']: + for param in [ + "dn_commonname", + "dn_state", + "dn_city", + "dn_organization", + "dn_organizationalunit", + ]: if re.search(pattern, self.params[param]): - self.module.fail_json(msg=f'{param} contains invalid characters') + self.module.fail_json(msg=f"{param} contains invalid characters") - if params['lifetime'] > self.max_lifetime: - self.module.fail_json(msg=f'Lifetime is longer than the maximum allowed value ({self.max_lifetime})') + if params["lifetime"] > self.max_lifetime: + self.module.fail_json( + msg=f"Lifetime is longer than the maximum allowed value ({self.max_lifetime})" + ) - if params['method'] == 'existing': - if params['certificate'] is None: + if params["method"] == "existing": + if params["certificate"] is None: self.module.fail_json(msg='Missing required argument "certificate"') # TODO - Make sure certificate purpose includes CA - cert = params['certificate'] - if re.match('LS0', cert): + cert = params["certificate"] + if re.match("LS0", cert): cert = base64.b64decode(cert.encode()).decode() lines = cert.splitlines() - if lines[0] == '-----BEGIN CERTIFICATE-----' and lines[-1] == '-----END CERTIFICATE-----': - params['certificate'] = base64.b64encode(cert.encode()).decode() + if ( + lines[0] == "-----BEGIN CERTIFICATE-----" + and lines[-1] == "-----END CERTIFICATE-----" + ): + params["certificate"] = base64.b64encode(cert.encode()).decode() else: - self.module.fail_json(msg='Could not recognize certificate format: %s' % (cert)) + self.module.fail_json( + msg="Could not recognize certificate format: %s" % (cert) + ) - if params['crl'] is not None: - crl = params['crl'] - if re.match('LS0', crl): + if params["crl"] is not None: + crl = params["crl"] + if re.match("LS0", crl): crl = base64.b64decode(crl.encode()).decode() lines = crl.splitlines() - if lines[0] == '-----BEGIN X509 CRL-----' and lines[-1] == '-----END X509 CRL-----': - params['crl'] = base64.b64encode(crl.encode()).decode() + if ( + lines[0] == "-----BEGIN X509 CRL-----" + and lines[-1] == "-----END X509 CRL-----" + ): + params["crl"] = base64.b64encode(crl.encode()).decode() else: - self.module.fail_json(msg='Could not recognize CRL format: %s' % (crl)) + self.module.fail_json( + msg="Could not recognize CRL format: %s" % (crl) + ) - if params['key'] is not None: - ca_key = params['key'] - if re.match('LS0', ca_key): + if params["key"] is not None: + ca_key = params["key"] + if re.match("LS0", ca_key): ca_key = base64.b64decode(ca_key.encode()).decode() lines = ca_key.splitlines() - if lines[0] == '-----BEGIN PRIVATE KEY-----' and lines[-1] == '-----END PRIVATE KEY-----': - params['key'] = base64.b64encode(ca_key.encode()).decode() + if ( + lines[0] == "-----BEGIN PRIVATE KEY-----" + and lines[-1] == "-----END PRIVATE KEY-----" + ): + params["key"] = base64.b64encode(ca_key.encode()).decode() else: - self.module.fail_json(msg='Could not recognize CA key format: %s' % (ca_key)) + self.module.fail_json( + msg="Could not recognize CA key format: %s" % (ca_key) + ) - if params['serial'] is not None: - if int(params['serial']) < 1: - self.module.fail_json(msg='serial must be greater than 0') + if params["serial"] is not None: + if int(params["serial"]) < 1: + self.module.fail_json(msg="serial must be greater than 0") def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() - obj['descr'] = params['name'] - if params['state'] == 'present': - - if params['method'] == 'existing': - if 'certificate' in params and params['certificate'] is not None: - obj['crt'] = params['certificate'] - if params['crl'] is not None: + obj["descr"] = params["name"] + if params["state"] == "present": + if params["method"] == "existing": + if "certificate" in params and params["certificate"] is not None: + obj["crt"] = params["certificate"] + if params["crl"] is not None: self.crl = {} - self.crl['method'] = 'existing' - self.crl['text'] = params['crl'] - self._get_ansible_param(self.crl, 'crlname', fname='descr', force=True, force_value=obj['descr'] + ' CRL') - self._get_ansible_param(self.crl, 'crlrefid', fname='refid') - if params['key'] is not None: - obj['prv'] = params['key'] + self.crl["method"] = "existing" + self.crl["text"] = params["crl"] + self._get_ansible_param( + self.crl, + "crlname", + fname="descr", + force=True, + force_value=obj["descr"] + " CRL", + ) + self._get_ansible_param(self.crl, "crlrefid", fname="refid") + if params["key"] is not None: + obj["prv"] = params["key"] for arg in CA_BOOL_VALUES: - self._get_ansible_param_bool(obj, arg, value=CA_BOOL_VALUES[arg][1], value_false=CA_BOOL_VALUES[arg][0]) + self._get_ansible_param_bool( + obj, + arg, + value=CA_BOOL_VALUES[arg][1], + value_false=CA_BOOL_VALUES[arg][0], + ) - self._get_ansible_param(obj, 'serial') + self._get_ansible_param(obj, "serial") return obj @@ -334,7 +475,11 @@ def _find_crl_for_ca(self, caref): if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple CRLs for caref {0}, you must specify crlname or crlrefid.'.format(caref)) + self.module.fail_json( + msg="Found multiple CRLs for caref {0}, you must specify crlname or crlrefid.".format( + caref + ) + ) else: return None @@ -343,7 +488,11 @@ def _find_crl_by_name(self, crlname): if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple CRLs for name {0}, you must specify crlrefid.'.format(crlname)) + self.module.fail_json( + msg="Found multiple CRLs for name {0}, you must specify crlrefid.".format( + crlname + ) + ) else: return None @@ -352,66 +501,70 @@ def _find_crl_by_refid(self, crlrefid): if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple CRLs for refid {0}. This is an unsupported condition'.format(crlrefid)) + self.module.fail_json( + msg="Found multiple CRLs for refid {0}. This is an unsupported condition".format( + crlrefid + ) + ) else: return None def _copy_and_add_target(self): - """ populate the XML target_elt """ + """populate the XML target_elt""" self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) self.root_elt.insert(self._find_last_element_index(), self.target_elt) if self.crl is not None: - crl_elt = self.pfsense.new_element('crl') - self.crl['caref'] = self.obj['refid'] - if 'refid' not in self.crl: - self.crl['refid'] = self.pfsense.uniqid() + crl_elt = self.pfsense.new_element("crl") + self.crl["caref"] = self.obj["refid"] + if "refid" not in self.crl: + self.crl["refid"] = self.pfsense.uniqid() self.pfsense.copy_dict_to_element(self.crl, crl_elt) - self.diff['after']['crl'] = self.crl['text'] + self.diff["after"]["crl"] = self.crl["text"] self.pfsense.root.append(crl_elt) self.refresh_crls = True def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" (before, changed) = super(PFSenseCAModule, self)._copy_and_update_target() if self.crl is not None: crl_elt = None # If a crlrefid is specified, update it or create a new one with that refid - if self.params['crlrefid'] is not None: - crl_elt = self._find_crl_by_refid(self.params['crlrefid']) - self.crl['refid'] = self.params['crlrefid'] + if self.params["crlrefid"] is not None: + crl_elt = self._find_crl_by_refid(self.params["crlrefid"]) + self.crl["refid"] = self.params["crlrefid"] else: - if self.params['crlname'] is not None: - crl_elt = self._find_crl_by_name(self.params['crlname']) + if self.params["crlname"] is not None: + crl_elt = self._find_crl_by_name(self.params["crlname"]) if crl_elt is None: - crl_elt = self._find_crl_for_ca(self.target_elt.find('refid').text) + crl_elt = self._find_crl_for_ca(self.target_elt.find("refid").text) if crl_elt is None: changed = True - crl_elt = self.pfsense.new_element('crl') - self.crl['caref'] = self.target_elt.find('refid').text - if 'refid' not in self.crl: - self.crl['refid'] = self.pfsense.uniqid() + crl_elt = self.pfsense.new_element("crl") + self.crl["caref"] = self.target_elt.find("refid").text + if "refid" not in self.crl: + self.crl["refid"] = self.pfsense.uniqid() self.pfsense.copy_dict_to_element(self.crl, crl_elt) # Add after the existing ca entry self.pfsense.root.insert(self._find_this_element_index() + 1, crl_elt) self.refresh_crls = True else: - before['crl'] = crl_elt.find('text').text - before['crlname'] = crl_elt.find('descr').text - if 'crlname' not in self.crl: - self.crl['descr'] = before['crlname'] - before['crlrefid'] = crl_elt.find('refid').text - if 'refid' not in self.crl: - self.crl['refid'] = before['crlrefid'] + before["crl"] = crl_elt.find("text").text + before["crlname"] = crl_elt.find("descr").text + if "crlname" not in self.crl: + self.crl["descr"] = before["crlname"] + before["crlrefid"] = crl_elt.find("refid").text + if "refid" not in self.crl: + self.crl["refid"] = before["crlrefid"] if self.pfsense.copy_dict_to_element(self.crl, crl_elt): changed = True self.refresh_crls = True - self.diff['after']['crl'] = self.crl['text'] - self.diff['after']['crlname'] = self.crl['descr'] - self.diff['after']['crlrefid'] = self.crl['refid'] + self.diff["after"]["crl"] = self.crl["text"] + self.diff["after"]["crlname"] = self.crl["descr"] + self.diff["after"]["crlrefid"] = self.crl["refid"] return (before, changed) @@ -419,32 +572,41 @@ def _copy_and_update_target(self): # run # def _update(self): - (dummy, stdout, stderr) = ('', '', '') - if self.params['state'] == 'present': - if self.params['method'] == 'existing': + (dummy, stdout, stderr) = ("", "", "") + if self.params["state"] == "present": + if self.params["method"] == "existing": # ca_import will base64 encode the cert + key and will fix 'caref' for CAs that reference each other # $ca needs to be an existing reference (particularly 'refid' must be set) before calling ca_import # key and serial are optional arguments. TODO - handle key and serial - (dummy, stdout, stderr) = self.pfsense.phpshell(""" + (dummy, stdout, stderr) = self.pfsense.phpshell( + """ $ca =& lookup_ca('{refid}')['item']; ca_import($ca, '{cert}'); write_config('Update CA reference'); ca_setup_trust_store(); - cert_restart_services(ca_get_all_services('{refid}'));""".format(refid=self.target_elt.find('refid').text, - cert=base64.b64decode(self.target_elt.find('crt').text.encode()).decode())) + cert_restart_services(ca_get_all_services('{refid}'));""".format( + refid=self.target_elt.find("refid").text, + cert=base64.b64decode( + self.target_elt.find("crt").text.encode() + ).decode(), + ) + ) if self.refresh_crls: - (dummy, crl_stdout, crl_stderr) = self.pfsense.phpshell(""" + (dummy, crl_stdout, crl_stderr) = self.pfsense.phpshell( + """ require_once("openvpn.inc"); openvpn_refresh_crls(); require_once("vpn.inc"); - ipsec_configure();""") + ipsec_configure();""" + ) stdout += crl_stdout stderr += crl_stderr - if self.params['method'] == 'internal': + if self.params["method"] == "internal": # Create an internal CA - (dummy, stdout, stderr) = self.pfsense.phpshell(""" + (dummy, stdout, stderr) = self.pfsense.phpshell( + """ $caent =& lookup_ca('{refid}'); $ca =& $caent['item']; @@ -483,43 +645,46 @@ def _update(self): $savemsg = sprintf(gettext("Created internal Certificate Authority %s"), $ca['descr']); config_set_path("ca/{{$caent['idx']}}", $ca); write_config($savemsg); - ca_setup_trust_store();""".format(refid=self.target_elt.find('refid').text, - dn_commonname=self.params['dn_commonname'], - dn_country=self.params['dn_country'], - dn_state=self.params['dn_state'], - dn_city=self.params['dn_city'], - dn_organization=self.params['dn_organization'], - dn_organizationalunit=self.params['dn_organizationalunit'], - keylen=self.params['keylen'], - lifetime=self.params['lifetime'], - keytype=self.params['keytype'], - digest_alg=self.params['digest_alg'], - ecname=self.params['ecname'])) + ca_setup_trust_store();""".format( + refid=self.target_elt.find("refid").text, + dn_commonname=self.params["dn_commonname"], + dn_country=self.params["dn_country"], + dn_state=self.params["dn_state"], + dn_city=self.params["dn_city"], + dn_organization=self.params["dn_organization"], + dn_organizationalunit=self.params["dn_organizationalunit"], + keylen=self.params["keylen"], + lifetime=self.params["lifetime"], + keytype=self.params["keytype"], + digest_alg=self.params["digest_alg"], + ecname=self.params["ecname"], + ) + ) return (dummy, stdout, stderr) def _pre_remove_target_elt(self): - self.diff['after'] = {} + self.diff["after"] = {} if self.target_elt is not None: - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) - crl_elt = self._find_crl_for_ca(self.target_elt.find('refid').text) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) + crl_elt = self._find_crl_for_ca(self.target_elt.find("refid").text) self.elements.remove(self.target_elt) if crl_elt is not None: - self.diff['before']['crl'] = crl_elt.find('text').text + self.diff["before"]["crl"] = crl_elt.find("text").text self.root_elt.remove(crl_elt) else: - self.diff['before'] = {} + self.diff["before"] = {} def main(): module = AnsibleModule( - argument_spec=PFSENSE_CA_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=PFSENSE_CA_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSenseCAModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_cert.py b/plugins/modules/pfsense_cert.py index d2043c3a..e5279499 100644 --- a/plugins/modules/pfsense_cert.py +++ b/plugins/modules/pfsense_cert.py @@ -5,12 +5,15 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -153,40 +156,120 @@ import re from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) CERT_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - ca=dict(type='str'), - keytype=dict(type='str', default='RSA', choices=['RSA', 'ECDSA']), - digestalg=dict(type='str', default='sha256', choices=['sha1', 'sha224', 'sha256', 'sha384', 'sha512']), + name=dict(required=True, type="str"), + ca=dict(type="str"), + keytype=dict(type="str", default="RSA", choices=["RSA", "ECDSA"]), + digestalg=dict( + type="str", + default="sha256", + choices=["sha1", "sha224", "sha256", "sha384", "sha512"], + ), ecname=dict( - type='str', - default='prime256v1', + type="str", + default="prime256v1", choices=[ - 'secp112r1', 'secp112r2', 'secp128r1', 'secp128r2', 'secp160k1', 'secp160r1', 'secp160r2', - 'secp192k1', 'secp224k1', 'secp224r1', 'secp256k1', 'secp384r1', 'secp521r1', 'prime192v1', 'prime192v2', 'prime192v3', 'prime239v1', - 'prime239v2', 'prime239v3', 'prime256v1', 'sect113r1', 'sect113r2', 'sect131r1', 'sect131r2', 'sect163k1', 'sect163r1', 'sect163r2', - 'sect193r1', 'sect193r2', 'sect233k1', 'sect233r1', 'sect239k1', 'sect283k1', 'sect283r1', 'sect409k1', 'sect409r1', 'sect571k1', 'sect571r1', - 'c2pnb163v1', 'c2pnb163v2', 'c2pnb163v3', 'c2pnb176v1', 'c2tnb191v1', 'c2tnb191v2', 'c2tnb191v3', 'c2pnb208w1', 'c2tnb239v1', 'c2tnb239v2', - 'c2tnb239v3', 'c2pnb272w1', 'c2pnb304w1', 'c2tnb359v1', 'c2pnb368w1', 'c2tnb431r1', 'wap-wsg-idm-ecid-wtls1', 'wap-wsg-idm-ecid-wtls3', - 'wap-wsg-idm-ecid-wtls4', 'wap-wsg-idm-ecid-wtls5', 'wap-wsg-idm-ecid-wtls6', 'wap-wsg-idm-ecid-wtls7', 'wap-wsg-idm-ecid-wtls8', - 'wap-wsg-idm-ecid-wtls9', 'wap-wsg-idm-ecid-wtls10', 'wap-wsg-idm-ecid-wtls11', 'wap-wsg-idm-ecid-wtls12', 'Oakley-EC2N-3', 'Oakley-EC2N-4', - 'brainpoolP160r1', 'brainpoolP160t1', 'brainpoolP192r1', 'brainpoolP192t1', 'brainpoolP224r1', 'brainpoolP224t1', 'brainpoolP256r1', - 'brainpoolP256t1', 'brainpoolP320r1', 'brainpoolP320t1', 'brainpoolP384r1', 'brainpoolP384t1', 'brainpoolP512r1', 'brainpoolP512t1', 'SM2']), - keylen=dict(type='str', default='2048'), - lifetime=dict(type='str', default='3650'), - dn_country=dict(type='str'), - dn_state=dict(type='str'), - dn_city=dict(type='str'), - dn_organization=dict(type='str'), - dn_organizationalunit=dict(type='str'), - altnames=dict(type='str'), - certificate=dict(type='str'), - key=dict(type='str', no_log=True), - state=dict(type='str', default='present', choices=['present', 'absent']), - method=dict(type='str', default='internal', choices=['internal', 'import']), - certtype=dict(type='str', default='user', choices=['user', 'server']), + "secp112r1", + "secp112r2", + "secp128r1", + "secp128r2", + "secp160k1", + "secp160r1", + "secp160r2", + "secp192k1", + "secp224k1", + "secp224r1", + "secp256k1", + "secp384r1", + "secp521r1", + "prime192v1", + "prime192v2", + "prime192v3", + "prime239v1", + "prime239v2", + "prime239v3", + "prime256v1", + "sect113r1", + "sect113r2", + "sect131r1", + "sect131r2", + "sect163k1", + "sect163r1", + "sect163r2", + "sect193r1", + "sect193r2", + "sect233k1", + "sect233r1", + "sect239k1", + "sect283k1", + "sect283r1", + "sect409k1", + "sect409r1", + "sect571k1", + "sect571r1", + "c2pnb163v1", + "c2pnb163v2", + "c2pnb163v3", + "c2pnb176v1", + "c2tnb191v1", + "c2tnb191v2", + "c2tnb191v3", + "c2pnb208w1", + "c2tnb239v1", + "c2tnb239v2", + "c2tnb239v3", + "c2pnb272w1", + "c2pnb304w1", + "c2tnb359v1", + "c2pnb368w1", + "c2tnb431r1", + "wap-wsg-idm-ecid-wtls1", + "wap-wsg-idm-ecid-wtls3", + "wap-wsg-idm-ecid-wtls4", + "wap-wsg-idm-ecid-wtls5", + "wap-wsg-idm-ecid-wtls6", + "wap-wsg-idm-ecid-wtls7", + "wap-wsg-idm-ecid-wtls8", + "wap-wsg-idm-ecid-wtls9", + "wap-wsg-idm-ecid-wtls10", + "wap-wsg-idm-ecid-wtls11", + "wap-wsg-idm-ecid-wtls12", + "Oakley-EC2N-3", + "Oakley-EC2N-4", + "brainpoolP160r1", + "brainpoolP160t1", + "brainpoolP192r1", + "brainpoolP192t1", + "brainpoolP224r1", + "brainpoolP224t1", + "brainpoolP256r1", + "brainpoolP256t1", + "brainpoolP320r1", + "brainpoolP320t1", + "brainpoolP384r1", + "brainpoolP384t1", + "brainpoolP512r1", + "brainpoolP512t1", + "SM2", + ], + ), + keylen=dict(type="str", default="2048"), + lifetime=dict(type="str", default="3650"), + dn_country=dict(type="str"), + dn_state=dict(type="str"), + dn_city=dict(type="str"), + dn_organization=dict(type="str"), + dn_organizationalunit=dict(type="str"), + altnames=dict(type="str"), + certificate=dict(type="str"), + key=dict(type="str", no_log=True), + state=dict(type="str", default="present", choices=["present", "absent"]), + method=dict(type="str", default="internal", choices=["internal", "import"]), + certtype=dict(type="str", default="user", choices=["user", "server"]), ) CERT_PHP_COMMAND_PREFIX = """ @@ -195,88 +278,98 @@ class PFSenseCertModule(PFSenseModuleBase): - """ module managing pfsense certificates """ + """module managing pfsense certificates""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return CERT_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseCertModule, self).__init__(module, pfsense, root='pfsense', node='cert') + super(PFSenseCertModule, self).__init__( + module, pfsense, root="pfsense", node="cert" + ) self.name = "pfsense_cert" ############################## # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params['state'] == 'absent': + if params["state"] == "absent": return - if params['method'] == 'internal': + if params["method"] == "internal": # An internal CA is required for internal certificate - if params['ca'] is None: - self.module.fail_json(msg='CA is required.') + if params["ca"] is None: + self.module.fail_json(msg="CA is required.") else: - ca = self._find_ca(params['ca']) + ca = self._find_ca(params["ca"]) if ca is not None: - if ca.find('prv') is None: - self.module.fail_json(msg='CA (%s) is not an internal CA' % params['ca']) + if ca.find("prv") is None: + self.module.fail_json( + msg="CA (%s) is not an internal CA" % params["ca"] + ) else: - self.module.fail_json(msg='CA (%s) not found' % params['ca']) + self.module.fail_json(msg="CA (%s) not found" % params["ca"]) # validate Certificate - if params['certificate'] is not None: - cert = params['certificate'] - if re.match('LS0', cert): + if params["certificate"] is not None: + cert = params["certificate"] + if re.match("LS0", cert): cert = base64.b64decode(cert.encode()).decode() lines = cert.splitlines() - if lines[0] == '-----BEGIN CERTIFICATE-----' and lines[-1] == '-----END CERTIFICATE-----': - params['certificate'] = base64.b64encode(cert.encode()).decode() + if ( + lines[0] == "-----BEGIN CERTIFICATE-----" + and lines[-1] == "-----END CERTIFICATE-----" + ): + params["certificate"] = base64.b64encode(cert.encode()).decode() else: - self.module.fail_json(msg='Could not recognize certificate format: %s' % (cert)) + self.module.fail_json( + msg="Could not recognize certificate format: %s" % (cert) + ) # validate key - if params['key'] is not None: - key = params['key'] - if re.match('LS0', key): + if params["key"] is not None: + key = params["key"] + if re.match("LS0", key): key = base64.b64decode(key.encode()).decode() lines = key.splitlines() - if re.match('^-----BEGIN ((EC|RSA) )?PRIVATE KEY-----$', lines[0]) and re.match('^-----END ((EC|RSA) )?PRIVATE KEY-----$', lines[-1]): - params['key'] = base64.b64encode(key.encode()).decode() + if re.match( + "^-----BEGIN ((EC|RSA) )?PRIVATE KEY-----$", lines[0] + ) and re.match("^-----END ((EC|RSA) )?PRIVATE KEY-----$", lines[-1]): + params["key"] = base64.b64encode(key.encode()).decode() else: - self.module.fail_json(msg='Could not recognize key format: %s' % (key)) + self.module.fail_json(msg="Could not recognize key format: %s" % (key)) def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() self.obj = obj # certificate description - obj['descr'] = params['name'] - if params['state'] == 'present': - - if params['ca'] is not None: + obj["descr"] = params["name"] + if params["state"] == "present": + if params["ca"] is not None: # found CA - ca = self._find_ca(params['ca']) + ca = self._find_ca(params["ca"]) if ca is not None: # get CA refid - obj['caref'] = ca.find('refid').text + obj["caref"] = ca.find("refid").text else: - self.module.fail_json(msg='CA (%s) not found' % params['ca']) + self.module.fail_json(msg="CA (%s) not found" % params["ca"]) - if 'certificate' in params and params['certificate'] is not None: - obj['crt'] = params['certificate'] - if 'key' in params and params['key'] is not None: - obj['prv'] = params['key'] + if "certificate" in params and params["certificate"] is not None: + obj["crt"] = params["certificate"] + if "key" in params and params["key"] is not None: + obj["prv"] = params["key"] return obj @@ -288,33 +381,35 @@ def _find_ca(self, caref): if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple CAs for caref {0}.'.format(caref)) + self.module.fail_json(msg="Found multiple CAs for caref {0}.".format(caref)) else: result = self.root_elt.findall("ca[refid='{0}']".format(caref)) if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple CAs for caref {0}.'.format(caref)) + self.module.fail_json( + msg="Found multiple CAs for caref {0}.".format(caref) + ) else: return None def _copy_and_add_target(self): - """ populate the XML target_elt """ + """populate the XML target_elt""" obj = self.obj - obj['refid'] = self.pfsense.uniqid() - self.diff['after'] = obj + obj["refid"] = self.pfsense.uniqid() + self.diff["after"] = obj self.pfsense.copy_dict_to_element(self.obj, self.target_elt) self.root_elt.insert(self._find_last_element_index(), self.target_elt) def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) - self.diff['before'] = before + self.diff["before"] = before changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) return (before, changed) @@ -322,38 +417,49 @@ def _copy_and_update_target(self): # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'" + self.obj['descr'] + "'" + """return obj's name""" + return "'" + self.obj["descr"] + "'" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'descr') + values += self.format_cli_field(self.params, "descr") else: - values += self.format_updated_cli_field(self.obj, before, 'descr', add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "descr", add_comma=(values) + ) return values ############################## # run # def _update(self): - if self.params['state'] == 'present': - if self.params['method'] == 'import': + if self.params["state"] == "present": + if self.params["method"] == "import": # import certificate - return self.pfsense.phpshell(""" + return self.pfsense.phpshell( + """ require_once('certs.inc'); $cert =& lookup_cert('{refid}'); cert_import($cert, '{cert}', '{key}'); $savemsg = sprintf(gettext("Imported certificate %s"), $cert['descr']); write_config($savemsg); cert_restart_services(cert_get_all_services('{refid}')); - """.format(refid=self.target_elt.find('refid').text, - cert=base64.b64decode(self.target_elt.find('crt').text.encode()).decode(), - key=base64.b64decode(self.target_elt.find('prv').text.encode()).decode())) + """.format( + refid=self.target_elt.find("refid").text, + cert=base64.b64decode( + self.target_elt.find("crt").text.encode() + ).decode(), + key=base64.b64decode( + self.target_elt.find("prv").text.encode() + ).decode(), + ) + ) else: # generate internal certificate - return self.pfsense.phpshell(""" + return self.pfsense.phpshell( + """ require_once('certs.inc'); $certent =& lookup_cert('{refid}'); $cert =& $certent['item']; @@ -408,42 +514,43 @@ def _update(self): $savemsg = sprintf(gettext("Created internal certificate %s"), $cert['descr']); write_config($savemsg); cert_restart_services(cert_get_all_services('{refid}')); - """.format(refid=self.target_elt.find('refid').text, - dn_commonname=self.params['name'], - dn_country=self.params['dn_country'], - dn_state=self.params['dn_state'], - dn_city=self.params['dn_city'], - dn_organization=self.params['dn_organization'], - dn_organizationalunit=self.params['dn_organizationalunit'], - altnames=self.params['altnames'], - caref=self.target_elt.find('caref').text, - keylen=self.params['keylen'], - lifetime=self.params['lifetime'], - certtype=self.params['certtype'], - keytype=self.params['keytype'], - digest_alg=self.params['digestalg'], - ecname=self.params['ecname'])) + """.format( + refid=self.target_elt.find("refid").text, + dn_commonname=self.params["name"], + dn_country=self.params["dn_country"], + dn_state=self.params["dn_state"], + dn_city=self.params["dn_city"], + dn_organization=self.params["dn_organization"], + dn_organizationalunit=self.params["dn_organizationalunit"], + altnames=self.params["altnames"], + caref=self.target_elt.find("caref").text, + keylen=self.params["keylen"], + lifetime=self.params["lifetime"], + certtype=self.params["certtype"], + keytype=self.params["keytype"], + digest_alg=self.params["digestalg"], + ecname=self.params["ecname"], + ) + ) else: - return (None, '', '') + return (None, "", "") def _pre_remove_target_elt(self): - self.diff['after'] = {} + self.diff["after"] = {} if self.target_elt is not None: - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) self.elements.remove(self.target_elt) else: - self.diff['before'] = {} + self.diff["before"] = {} def main(): - module = AnsibleModule( - argument_spec=CERT_ARGUMENT_SPEC, - supports_check_mode=True) + module = AnsibleModule(argument_spec=CERT_ARGUMENT_SPEC, supports_check_mode=True) pfmodule = PFSenseCertModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_default_gateway.py b/plugins/modules/pfsense_default_gateway.py index 370733ec..2f4a379a 100644 --- a/plugins/modules/pfsense_default_gateway.py +++ b/plugins/modules/pfsense_default_gateway.py @@ -8,11 +8,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -70,19 +73,21 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.default_gateway import PFSenseDefaultGatewayModule, \ - DEFAULT_GATEWAY_ARGUMENT_SPEC +from ansible_collections.pfsensible.core.plugins.module_utils.default_gateway import ( + PFSenseDefaultGatewayModule, + DEFAULT_GATEWAY_ARGUMENT_SPEC, +) def main(): module = AnsibleModule( - argument_spec=DEFAULT_GATEWAY_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=DEFAULT_GATEWAY_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSenseDefaultGatewayModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_dhcp_server.py b/plugins/modules/pfsense_dhcp_server.py index 23ed1afe..d88c6730 100644 --- a/plugins/modules/pfsense_dhcp_server.py +++ b/plugins/modules/pfsense_dhcp_server.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '6.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "6.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -188,17 +191,20 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.dhcp_server import PFSenseDHCPServerModule, DHCPSERVER_ARGUMENT_SPEC +from ansible_collections.pfsensible.core.plugins.module_utils.dhcp_server import ( + PFSenseDHCPServerModule, + DHCPSERVER_ARGUMENT_SPEC, +) def main(): module = AnsibleModule( - argument_spec=DHCPSERVER_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=DHCPSERVER_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSenseDHCPServerModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_dhcp_static.py b/plugins/modules/pfsense_dhcp_static.py index b5f38aaf..7a22daff 100644 --- a/plugins/modules/pfsense_dhcp_static.py +++ b/plugins/modules/pfsense_dhcp_static.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -160,59 +163,71 @@ import re from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) DHCP_STATIC_ARGUMENT_SPEC = dict( - name=dict(type='str', aliases=['cid']), - macaddr=dict(type='str'), - netif=dict(type='str'), - ipaddr=dict(type='str'), - hostname=dict(type='str'), - descr=dict(type='str'), - filename=dict(type='str'), - rootpath=dict(type='str'), - defaultleasetime=dict(type='str'), - maxleasetime=dict(type='str'), - gateway=dict(type='str'), - domain=dict(type='str'), - domainsearchlist=dict(type='str'), - winsserver=dict(type='list', elements='str'), - dnsserver=dict(type='list', elements='str'), - ntpserver=dict(type='list', elements='str'), - ddnsdomain=dict(type='str'), - ddnsdomainprimary=dict(type='str'), - ddnsdomainsecondary=dict(type='str'), - ddnsdomainkeyname=dict(type='str'), - ddnsdomainkeyalgorithm=dict(type='str', choices=['hmac-md5', 'hmac-sha1', 'hmac-sha224', 'hmac-sha256', 'hmac-sha384', 'hmac-sha512']), - ddnsdomainkey=dict(type='str', no_log=True), - tftp=dict(type='str'), - ldap=dict(type='str'), - nextserver=dict(type='str'), - filename32=dict(type='str'), - filename64=dict(type='str'), - filename32arm=dict(type='str'), - filename64arm=dict(type='str'), - uefihttpboot=dict(type='str'), - numberoptions=dict(type='str'), - arp_table_static_entry=dict(default=False, type='bool'), - state=dict(type='str', default='present', choices=['present', 'absent']), + name=dict(type="str", aliases=["cid"]), + macaddr=dict(type="str"), + netif=dict(type="str"), + ipaddr=dict(type="str"), + hostname=dict(type="str"), + descr=dict(type="str"), + filename=dict(type="str"), + rootpath=dict(type="str"), + defaultleasetime=dict(type="str"), + maxleasetime=dict(type="str"), + gateway=dict(type="str"), + domain=dict(type="str"), + domainsearchlist=dict(type="str"), + winsserver=dict(type="list", elements="str"), + dnsserver=dict(type="list", elements="str"), + ntpserver=dict(type="list", elements="str"), + ddnsdomain=dict(type="str"), + ddnsdomainprimary=dict(type="str"), + ddnsdomainsecondary=dict(type="str"), + ddnsdomainkeyname=dict(type="str"), + ddnsdomainkeyalgorithm=dict( + type="str", + choices=[ + "hmac-md5", + "hmac-sha1", + "hmac-sha224", + "hmac-sha256", + "hmac-sha384", + "hmac-sha512", + ], + ), + ddnsdomainkey=dict(type="str", no_log=True), + tftp=dict(type="str"), + ldap=dict(type="str"), + nextserver=dict(type="str"), + filename32=dict(type="str"), + filename64=dict(type="str"), + filename32arm=dict(type="str"), + filename64arm=dict(type="str"), + uefihttpboot=dict(type="str"), + numberoptions=dict(type="str"), + arp_table_static_entry=dict(default=False, type="bool"), + state=dict(type="str", default="present", choices=["present", "absent"]), ) DHCP_STATIC_REQUIRED_IF = [ - ['arp_table_static_entry', True, ['ipaddr']], + ["arp_table_static_entry", True, ["ipaddr"]], ] DHCP_STATIC_REQUIRED_ONE_OF = [ - ('name', 'macaddr'), + ("name", "macaddr"), ] class PFSenseDHCPStaticModule(PFSenseModuleBase): - """ module managing pfsense dhcp static configuration """ + """module managing pfsense dhcp static configuration""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return DHCP_STATIC_ARGUMENT_SPEC ############################## @@ -221,7 +236,7 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseDHCPStaticModule, self).__init__(module, pfsense) self.name = "pfsense_dhcp_static" - self.dhcpd = self.pfsense.get_element('dhcpd') + self.dhcpd = self.pfsense.get_element("dhcpd") self.root_elt = None self.staticmaps = None @@ -229,55 +244,84 @@ def __init__(self, module, pfsense=None): # params processing # def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params['macaddr'] is not None and re.fullmatch(r'(?:[0-9a-fA-F]{2}:){5}[0-9a-fA-F]{2}', params['macaddr']) is None: - self.module.fail_json(msg='A valid MAC address must be specified.') - - if params['netif'] is not None: - if self.pfsense.is_interface_group(params['netif']): - self.module.fail_json(msg='DHCP cannot be configured for interface groups') + if ( + params["macaddr"] is not None + and re.fullmatch(r"(?:[0-9a-fA-F]{2}:){5}[0-9a-fA-F]{2}", params["macaddr"]) + is None + ): + self.module.fail_json(msg="A valid MAC address must be specified.") + + if params["netif"] is not None: + if self.pfsense.is_interface_group(params["netif"]): + self.module.fail_json( + msg="DHCP cannot be configured for interface groups" + ) else: - netif = self.pfsense.parse_interface(params['netif']) + netif = self.pfsense.parse_interface(params["netif"]) else: netif = None # find staticmaps and determine interface self._find_staticmaps(netif) - if params['ipaddr'] is not None: - addr = ip_address(u'{0}'.format(params['ipaddr'])) + if params["ipaddr"] is not None: + addr = ip_address("{0}".format(params["ipaddr"])) if addr not in self.network: - self.module.fail_json(msg='The IP address must lie in the {0} subnet.'.format(self.netif)) + self.module.fail_json( + msg="The IP address must lie in the {0} subnet.".format(self.netif) + ) def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() self.obj = obj # client identifier - self._get_ansible_param(obj, 'name', fname='cid', force=True) - - if params['state'] == 'present': + self._get_ansible_param(obj, "name", fname="cid", force=True) - self._get_ansible_param(obj, 'macaddr', fname='mac', force=True) + if params["state"] == "present": + self._get_ansible_param(obj, "macaddr", fname="mac", force=True) # Forced options - for option in ['ipaddr', 'hostname', 'descr', 'filename', - 'rootpath', 'defaultleasetime', 'maxleasetime', - 'gateway', 'domain', 'domainsearchlist', - 'ddnsdomain', 'ddnsdomainprimary', 'ddnsdomainsecondary', - 'ddnsdomainkeyname', 'ddnsdomainkeyalgorithm', 'ddnsdomainkey', - 'tftp', 'ldap', 'nextserver', 'filename32', 'filename64', - 'filename32arm', 'filename64arm', 'uefihttpboot', 'numberoptions']: + for option in [ + "ipaddr", + "hostname", + "descr", + "filename", + "rootpath", + "defaultleasetime", + "maxleasetime", + "gateway", + "domain", + "domainsearchlist", + "ddnsdomain", + "ddnsdomainprimary", + "ddnsdomainsecondary", + "ddnsdomainkeyname", + "ddnsdomainkeyalgorithm", + "ddnsdomainkey", + "tftp", + "ldap", + "nextserver", + "filename32", + "filename64", + "filename32arm", + "filename64arm", + "uefihttpboot", + "numberoptions", + ]: self._get_ansible_param(obj, option, force=True) # Non-forced options - for option in ['winsserver', 'dnsserver', 'ntpserver']: + for option in ["winsserver", "dnsserver", "ntpserver"]: self._get_ansible_param(obj, option) # Defaulted options - self._get_ansible_param(obj, 'ddnsdomainkeyalgorithm', force_value='hmac-md5', force=True) + self._get_ansible_param( + obj, "ddnsdomainkeyalgorithm", force_value="hmac-md5", force=True + ) self._get_ansible_param_bool(obj, "arp_table_static_entry", value="") return obj @@ -288,26 +332,30 @@ def _params_to_obj(self): def _is_valid_netif(self, netif): for nic in self.pfsense.interfaces: if nic.tag == netif: - if nic.find('ipaddr') is not None: - ipaddr = nic.find('ipaddr').text + if nic.find("ipaddr") is not None: + ipaddr = nic.find("ipaddr").text if ipaddr is not None: - if nic.find('subnet') is not None: - subnet = int(nic.find('subnet').text) + if nic.find("subnet") is not None: + subnet = int(nic.find("subnet").text) if subnet < 31: - self.network = ip_network(u'{0}/{1}'.format(ipaddr, subnet), strict=False) + self.network = ip_network( + "{0}/{1}".format(ipaddr, subnet), strict=False + ) return True return False def _find_staticmaps(self, netif=None): for e in self.dhcpd: if netif is None or e.tag == netif: - if e.find('enable') is not None: + if e.find("enable") is not None: if self._is_valid_netif(e.tag): if self.root_elt is not None: - self.module.fail_json(msg='Multiple DHCP servers enabled and no netif specified') + self.module.fail_json( + msg="Multiple DHCP servers enabled and no netif specified" + ) self.root_elt = e self.netif = e.tag - self.staticmaps = self.root_elt.findall('staticmap') + self.staticmaps = self.root_elt.findall("staticmap") if netif is not None: break @@ -315,81 +363,110 @@ def _find_staticmaps(self, netif=None): if netif is None: self.module.fail_json(msg="No DHCP configuration") else: - self.module.fail_json(msg="No DHCP configuration found for netif='{0}'".format(netif)) + self.module.fail_json( + msg="No DHCP configuration found for netif='{0}'".format(netif) + ) - self.result['netif'] = netif + self.result["netif"] = netif def _find_target(self): - if self.params['name'] is not None and self.params['macaddr'] is not None: - result = self.root_elt.findall("staticmap[cid='{0}'][mac='{1}']".format(self.params['name'], self.params['macaddr'])) - elif self.params['name'] is not None: - result = self.root_elt.findall("staticmap[cid='{0}']".format(self.params['name'])) + if self.params["name"] is not None and self.params["macaddr"] is not None: + result = self.root_elt.findall( + "staticmap[cid='{0}'][mac='{1}']".format( + self.params["name"], self.params["macaddr"] + ) + ) + elif self.params["name"] is not None: + result = self.root_elt.findall( + "staticmap[cid='{0}']".format(self.params["name"]) + ) else: - result = self.root_elt.findall("staticmap[mac='{0}']".format(self.params['macaddr'])) + result = self.root_elt.findall( + "staticmap[mac='{0}']".format(self.params["macaddr"]) + ) if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple static maps for cid {0}.'.format(self.obj['cid'])) + self.module.fail_json( + msg="Found multiple static maps for cid {0}.".format(self.obj["cid"]) + ) else: return None def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('staticmap') + """create the XML target_elt""" + return self.pfsense.new_element("staticmap") def _copy_and_add_target(self): - """ populate the XML target_elt """ + """populate the XML target_elt""" super(PFSenseDHCPStaticModule, self)._copy_and_add_target() # Reset static map list - self.staticmaps = self.root_elt.findall('staticmap') + self.staticmaps = self.root_elt.findall("staticmap") @staticmethod def _get_params_to_remove(): - """ returns the list of params to remove if they are not set """ - return ['arp_table_static_entry'] + """returns the list of params to remove if they are not set""" + return ["arp_table_static_entry"] ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return "'" + self.obj['cid'] + "'" + """return obj's name""" + return "'" + self.obj["cid"] + "'" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" if before is None: - values += self.format_cli_field(self.params, 'macaddr') - values += self.format_cli_field(self.params, 'ipaddr') - values += self.format_cli_field(self.params, 'arp_table_static_entry', fvalue=self.fvalue_bool, default=False) + values += self.format_cli_field(self.params, "macaddr") + values += self.format_cli_field(self.params, "ipaddr") + values += self.format_cli_field( + self.params, + "arp_table_static_entry", + fvalue=self.fvalue_bool, + default=False, + ) else: - values += self.format_updated_cli_field(self.obj, before, 'macaddr', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'ipaddr', add_comma=(values)) - values += self.format_updated_cli_field(self.obj, before, 'arp_table_static_entry', fvalue=self.fvalue_bool, add_comma=(values)) + values += self.format_updated_cli_field( + self.obj, before, "macaddr", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, before, "ipaddr", add_comma=(values) + ) + values += self.format_updated_cli_field( + self.obj, + before, + "arp_table_static_entry", + fvalue=self.fvalue_bool, + add_comma=(values), + ) return values ############################## # run # def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell(""" + """make the target pfsense reload""" + return self.pfsense.phpshell( + """ require_once("util.inc"); require_once("services.inc"); $retvaldhcp = services_dhcpd_configure(); if ($retvaldhcp == 0) { clear_subsystem_dirty('dhcpd'); - }""") + }""" + ) def _pre_remove_target_elt(self): - self.diff['after'] = {} + self.diff["after"] = {} if self.target_elt is not None: - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) self.staticmaps.remove(self.target_elt) else: - self.diff['before'] = {} + self.diff["before"] = {} def main(): @@ -397,12 +474,13 @@ def main(): argument_spec=DHCP_STATIC_ARGUMENT_SPEC, required_if=DHCP_STATIC_REQUIRED_IF, required_one_of=DHCP_STATIC_REQUIRED_ONE_OF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseDHCPStaticModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_dns_resolver.py b/plugins/modules/pfsense_dns_resolver.py index ff527c1c..e2980935 100644 --- a/plugins/modules/pfsense_dns_resolver.py +++ b/plugins/modules/pfsense_dns_resolver.py @@ -5,12 +5,15 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -298,7 +301,9 @@ """ -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) from ansible.module_utils.basic import AnsibleModule import base64 import re @@ -307,75 +312,104 @@ # TODO: alias for DNS record DNS_RESOLVER_DOMAIN_OVERRIDE_SPEC = dict( - domain=dict(required=True, type='str'), - ip=dict(required=True, type='str'), - descr=dict(type='str'), - tls_hostname=dict(default='', type='str'), - forward_tls_upstream=dict(default=False, type='bool'), + domain=dict(required=True, type="str"), + ip=dict(required=True, type="str"), + descr=dict(type="str"), + tls_hostname=dict(default="", type="str"), + forward_tls_upstream=dict(default=False, type="bool"), ) DNS_RESOLVER_HOST_ALIAS_SPEC = dict( - host=dict(required=True, type='str'), - domain=dict(required=True, type='str'), - description=dict(required=True, type='str'), + host=dict(required=True, type="str"), + domain=dict(required=True, type="str"), + description=dict(required=True, type="str"), ) DNS_RESOLVER_HOST_SPEC = dict( - host=dict(required=True, type='str'), - domain=dict(required=True, type='str'), - ip=dict(required=True, type='str'), - descr=dict(default="", type='str'), - aliases=dict(default=[], type='list', elements='dict', options=DNS_RESOLVER_HOST_ALIAS_SPEC), + host=dict(required=True, type="str"), + domain=dict(required=True, type="str"), + ip=dict(required=True, type="str"), + descr=dict(default="", type="str"), + aliases=dict( + default=[], type="list", elements="dict", options=DNS_RESOLVER_HOST_ALIAS_SPEC + ), ) DNS_RESOLVER_ARGUMENT_SPEC = dict( - state=dict(default='present', choices=['present', 'absent']), - + state=dict(default="present", choices=["present", "absent"]), # General Settings - port=dict(default=None, type='int'), - enablessl=dict(default=False, type='bool'), - sslcert=dict(default="", type='str'), # need transform - tlsport=dict(default=None, type='int'), - active_interface=dict(default=["all"], type='list', elements='str'), - outgoing_interface=dict(default=["all"], type='list', elements='str'), + port=dict(default=None, type="int"), + enablessl=dict(default=False, type="bool"), + sslcert=dict(default="", type="str"), # need transform + tlsport=dict(default=None, type="int"), + active_interface=dict(default=["all"], type="list", elements="str"), + outgoing_interface=dict(default=["all"], type="list", elements="str"), # TODO: Strict Outgoing Network interface Binding: check box option - system_domain_local_zone_type=dict(default='transparent', choices=['deny', 'refuse', 'static', 'transparent', 'typetransparent', 'redirect', 'inform', - 'inform_deny', 'nodefault']), - dnssec=dict(default=True, type='bool'), + system_domain_local_zone_type=dict( + default="transparent", + choices=[ + "deny", + "refuse", + "static", + "transparent", + "typetransparent", + "redirect", + "inform", + "inform_deny", + "nodefault", + ], + ), + dnssec=dict(default=True, type="bool"), # TODO: Python Module: Enable the Python Module. These 3 options omited when disabled # python=dict(default=False, type='bool'), # python_order=dict(default="pre_validator", type='str', choices=["pre_validator", "post_validator"]), # python_script=dict(default="", type='str'), #Not sure what this is or how to handle it. - forwarding=dict(default=False, type='bool'), - forward_tls_upstream=dict(default=False, type='bool'), - regdhcp=dict(default=False, type='bool'), - regdhcpstatic=dict(default=False, type='bool'), - regovpnclients=dict(default=False, type='bool'), - custom_options=dict(default="", type='str'), - hosts=dict(default=[], type='list', elements='dict', options=DNS_RESOLVER_HOST_SPEC), - domainoverrides=dict(type='list', elements='dict', options=DNS_RESOLVER_DOMAIN_OVERRIDE_SPEC), + forwarding=dict(default=False, type="bool"), + forward_tls_upstream=dict(default=False, type="bool"), + regdhcp=dict(default=False, type="bool"), + regdhcpstatic=dict(default=False, type="bool"), + regovpnclients=dict(default=False, type="bool"), + custom_options=dict(default="", type="str"), + hosts=dict( + default=[], type="list", elements="dict", options=DNS_RESOLVER_HOST_SPEC + ), + domainoverrides=dict( + type="list", elements="dict", options=DNS_RESOLVER_DOMAIN_OVERRIDE_SPEC + ), # Advanced Settings - hideidentity=dict(default=True, type='bool'), - hideversion=dict(default=True, type='bool'), + hideidentity=dict(default=True, type="bool"), + hideversion=dict(default=True, type="bool"), # TODO: Query Name Minimization # TODO: Strict Query Name Minimization - prefetch=dict(default=False, type='bool'), - prefetchkey=dict(default=False, type='bool'), - dnssecstripped=dict(default=True, type='bool'), + prefetch=dict(default=False, type="bool"), + prefetchkey=dict(default=False, type="bool"), + dnssecstripped=dict(default=True, type="bool"), # TODO: Serve Expired # TODO: Aggressive NSEC - msgcachesize=dict(default=4, type='int', choices=[4, 10, 20, 50, 100, 250, 512]), - outgoing_num_tcp=dict(default=10, type='int', choices=[0, 10, 20, 30, 50]), - incoming_num_tcp=dict(default=10, type='int', choices=[0, 10, 20, 30, 50]), - edns_buffer_size=dict(default="auto", type='str', choices=["auto", "512", "1220", "1232", "1432", "1480", "4096"]), - num_queries_per_thread=dict(default=512, type='int', choices=[512, 1024, 2048]), - jostle_timeout=dict(default=200, type='int', choices=[100, 200, 500, 1000]), - cache_max_ttl=dict(default=86400, type='int'), - cache_min_ttl=dict(default=0, type='int'), - infra_host_ttl=dict(default=900, type='int', choices=[60, 120, 300, 600, 900]), - infra_cache_numhosts=dict(default=10000, type='int', choices=[1000, 5000, 10000, 20000, 50000, 100000, 200000]), - unwanted_reply_threshold=dict(default="disabled", type='str', choices=["disabled", "5000000", "10000000", "20000000", "40000000", "50000000"]), - log_verbosity=dict(default=1, type='int', choices=[0, 1, 2, 3, 4, 5]) + msgcachesize=dict(default=4, type="int", choices=[4, 10, 20, 50, 100, 250, 512]), + outgoing_num_tcp=dict(default=10, type="int", choices=[0, 10, 20, 30, 50]), + incoming_num_tcp=dict(default=10, type="int", choices=[0, 10, 20, 30, 50]), + edns_buffer_size=dict( + default="auto", + type="str", + choices=["auto", "512", "1220", "1232", "1432", "1480", "4096"], + ), + num_queries_per_thread=dict(default=512, type="int", choices=[512, 1024, 2048]), + jostle_timeout=dict(default=200, type="int", choices=[100, 200, 500, 1000]), + cache_max_ttl=dict(default=86400, type="int"), + cache_min_ttl=dict(default=0, type="int"), + infra_host_ttl=dict(default=900, type="int", choices=[60, 120, 300, 600, 900]), + infra_cache_numhosts=dict( + default=10000, + type="int", + choices=[1000, 5000, 10000, 20000, 50000, 100000, 200000], + ), + unwanted_reply_threshold=dict( + default="disabled", + type="str", + choices=["disabled", "5000000", "10000000", "20000000", "40000000", "50000000"], + ), + log_verbosity=dict(default=1, type="int", choices=[0, 1, 2, 3, 4, 5]), # TODO: Disable Auto-added Access Control # TODO: Disable Auto-added Host Entries # TODO: Experimental Bit 0x20 Support @@ -386,11 +420,11 @@ class PFSenseDNSResolverModule(PFSenseModuleBase): - """ module managing pfsense dns resolver (unbound) """ + """module managing pfsense dns resolver (unbound)""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return DNS_RESOLVER_ARGUMENT_SPEC ############################## @@ -399,18 +433,20 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseDNSResolverModule, self).__init__(module, pfsense) self.name = "pfsense_dns_resolver" - self.root_elt = self.pfsense.get_element('unbound') + self.root_elt = self.pfsense.get_element("unbound") self.obj = dict() self.interface_elt = None self.dynamic = False if self.root_elt is None: - self.root_elt = self.pfsense.new_element('unbound') + self.root_elt = self.pfsense.new_element("unbound") self.pfsense.root.append(self.root_elt) - cmd = ('require_once("interfaces.inc");;' - '$iflist = get_possible_listen_ips(true);' - 'echo json_encode($iflist);') + cmd = ( + 'require_once("interfaces.inc");;' + "$iflist = get_possible_listen_ips(true);" + "echo json_encode($iflist);" + ) self.iflist = self.pfsense.php(cmd) def _get_interface_name(self, iface: str): @@ -422,29 +458,38 @@ def _get_interface_name(self, iface: str): if ifacelow == iname.lower() or ifacelow == idescr.lower(): return iname # Virtual IPs are listed in the format "IP" or "IP (Description)" - allow specifying either IP or Description - if re.match(f"{re.escape(ifacelow)}(?: \\(|$)", idescr.lower()) or re.search(f" \\({re.escape(ifacelow)}\\)$", idescr.lower()): + if re.match( + f"{re.escape(ifacelow)}(?: \\(|$)", idescr.lower() + ) or re.search(f" \\({re.escape(ifacelow)}\\)$", idescr.lower()): return iname self.module.fail_json(msg=f"Invalid interface '{iface}'") def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() if params["state"] == "present": - obj["enable"] = "" - obj["active_interface"] = ",".join(self._get_interface_name(x) for x in params["active_interface"]) - obj["outgoing_interface"] = ",".join(self._get_interface_name(x) for x in params["outgoing_interface"]) - obj["custom_options"] = base64.b64encode(bytes(params['custom_options'], 'utf-8')).decode() + obj["active_interface"] = ",".join( + self._get_interface_name(x) for x in params["active_interface"] + ) + obj["outgoing_interface"] = ",".join( + self._get_interface_name(x) for x in params["outgoing_interface"] + ) + obj["custom_options"] = base64.b64encode( + bytes(params["custom_options"], "utf-8") + ).decode() self._get_ansible_param_bool(obj, "hideidentity", value="") self._get_ansible_param_bool(obj, "hideversion", value="") self._get_ansible_param_bool(obj, "dnssecstripped", value="") self._get_ansible_param(obj, "port") self._get_ansible_param(obj, "tlsport") if params["sslcert"]: - obj["sslcertref"] = self.pfsense.find_cert_elt(params["sslcert"]).find("refid").text + obj["sslcertref"] = ( + self.pfsense.find_cert_elt(params["sslcert"]).find("refid").text + ) self._get_ansible_param_bool(obj, "forwarding", value="") self._get_ansible_param(obj, "system_domain_local_zone_type") self._get_ansible_param_bool(obj, "regdhcp", value="") @@ -470,20 +515,28 @@ def _params_to_obj(self): self._get_ansible_param(obj, "hosts") self._get_ansible_param(obj, "domainoverrides") for domainoverride in obj.get("domainoverrides", []): - self._get_ansible_param_bool(domainoverride, "forward_tls_upstream", value="", params=domainoverride) - - if ((self.pfsense.config_get_path('system/dnslocalhost') != 'remote') and ("lo0" not in obj['active_interface']) and - ("all" not in obj['active_interface'])): - self.module.fail_json(msg="This system is configured to use the DNS Resolver as its DNS server, so Localhost or All must be selected in" - " active_interface.") + self._get_ansible_param_bool( + domainoverride, + "forward_tls_upstream", + value="", + params=domainoverride, + ) + + if ( + (self.pfsense.config_get_path("system/dnslocalhost") != "remote") + and ("lo0" not in obj["active_interface"]) + and ("all" not in obj["active_interface"]) + ): + self.module.fail_json( + msg="This system is configured to use the DNS Resolver as its DNS server, so Localhost or All must be selected in" + " active_interface." + ) # wrap to all hosts.alias for host in obj["hosts"]: if host["aliases"]: tmp_aliases = host["aliases"] - host["aliases"] = { - "item": tmp_aliases - } + host["aliases"] = {"item": tmp_aliases} else: # Default is an empty element host["aliases"] = "" @@ -491,47 +544,64 @@ def _params_to_obj(self): return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params if params["sslcert"] and not self.pfsense.find_cert_elt(params["sslcert"]): - self.module.fail_json(msg=f'sslcert, {params["sslcert"]} is not a valid description of cert') + self.module.fail_json( + msg=f"sslcert, {params['sslcert']} is not a valid description of cert" + ) for host in params["hosts"]: for ipaddr in host["ip"].split(","): if not self.pfsense.is_ipv4_address(ipaddr): - self.module.fail_json(msg=f'ip, {ipaddr} is not a ipv4 address') + self.module.fail_json(msg=f"ip, {ipaddr} is not a ipv4 address") if params["domainoverrides"] is not None: for domain in params["domainoverrides"]: if not self.pfsense.is_ipv4_address(domain["ip"]): - self.module.fail_json(msg=f'ip, {domain["ip"]} is not a ipv4 address') + self.module.fail_json( + msg=f"ip, {domain['ip']} is not a ipv4 address" + ) ############################## # XML processing # def _create_target(self): - """ create the XML target_elt """ + """create the XML target_elt""" return self.root_elt def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" return self.root_elt def _get_params_to_remove(self): - """ returns the list of params to remove if they are not set """ + """returns the list of params to remove if they are not set""" if self.params["state"] == "absent": return ["enable"] else: - return ["hideidentity", "hideversion", "dnssecstripped", "forwarding", "regdhcp", "regdhcpstatic", "regovpnclients", "enablessl", "dnssec", - "forward_tls_upstream", "prefetch", "prefetchkey"] + return [ + "hideidentity", + "hideversion", + "dnssecstripped", + "forwarding", + "regdhcp", + "regdhcpstatic", + "regovpnclients", + "enablessl", + "dnssec", + "forward_tls_upstream", + "prefetch", + "prefetchkey", + ] ############################## # run # def _update(self): - """ make the target pfsense reload """ - return self.pfsense.phpshell(''' + """make the target pfsense reload""" + return self.pfsense.phpshell( + """ require_once("unbound.inc"); require_once("pfsense-utils.inc"); require_once("system.inc"); @@ -540,47 +610,181 @@ def _update(self): system_resolvconf_generate(); system_dhcpleases_configure(); clear_subsystem_dirty("unbound"); -''') +""" + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ + """return obj's name""" return self.name def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' - - values += self.format_updated_cli_field(self.obj, before, 'enable', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'active_interface', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'outgoing_interface', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'custom_options', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'hideidentity', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'hideversion', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'dnssecstripped', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'port', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'tlsport', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'sslcertref', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'forwarding', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'system_domain_local_zone_type', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'regdhcp', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'regdhcpstatic', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'prefetch', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'prefetchkey', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'msgcachesize', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'outgoing_num_tcp', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'incoming_num_tcp', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'edns_buffer_size', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'num_queries_per_thread', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'jostle_timeout', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'cache_max_ttl', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'cache_min_ttl', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'infra_host_ttl', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'infra_cache_numhosts', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'unwanted_reply_threshold', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, before, 'log_verbosity', add_comma=(values), log_none=False) + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" + + values += self.format_updated_cli_field( + self.obj, + before, + "enable", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, before, "active_interface", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "outgoing_interface", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, + before, + "custom_options", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "hideidentity", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "hideversion", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "dnssecstripped", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "port", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "tlsport", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "sslcertref", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "forwarding", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "system_domain_local_zone_type", + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "regdhcp", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "regdhcpstatic", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "prefetch", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + before, + "prefetchkey", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, before, "msgcachesize", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "outgoing_num_tcp", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "incoming_num_tcp", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "edns_buffer_size", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, + before, + "num_queries_per_thread", + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, before, "jostle_timeout", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "cache_max_ttl", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "cache_min_ttl", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "infra_host_ttl", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, before, "infra_cache_numhosts", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + self.obj, + before, + "unwanted_reply_threshold", + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, before, "log_verbosity", add_comma=(values), log_none=False + ) # todo: hosts and domainoverrides is not logged return values @@ -590,12 +794,13 @@ def main(): module = AnsibleModule( argument_spec=DNS_RESOLVER_ARGUMENT_SPEC, required_if=DNS_RESOLVER_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseDNSResolverModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_gateway.py b/plugins/modules/pfsense_gateway.py index 98b2501a..1defc901 100644 --- a/plugins/modules/pfsense_gateway.py +++ b/plugins/modules/pfsense_gateway.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -105,19 +108,24 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.gateway import PFSenseGatewayModule, GATEWAY_ARGUMENT_SPEC, GATEWAY_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.gateway import ( + PFSenseGatewayModule, + GATEWAY_ARGUMENT_SPEC, + GATEWAY_REQUIRED_IF, +) def main(): module = AnsibleModule( argument_spec=GATEWAY_ARGUMENT_SPEC, required_if=GATEWAY_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseGatewayModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_group.py b/plugins/modules/pfsense_group.py index da42e98a..33eb09f1 100644 --- a/plugins/modules/pfsense_group.py +++ b/plugins/modules/pfsense_group.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -72,60 +75,68 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) GROUP_PHP_COMMAND_PREFIX = """ require_once('auth.inc'); """ -GROUP_PHP_COMMAND_SET = GROUP_PHP_COMMAND_PREFIX + """ +GROUP_PHP_COMMAND_SET = ( + GROUP_PHP_COMMAND_PREFIX + + """ $group = config_get_path('system/group')[{idx}]; local_group_set($group); """ +) # This runs after we remove the group from the config so we can't use it -GROUP_PHP_COMMAND_DEL = GROUP_PHP_COMMAND_PREFIX + """ +GROUP_PHP_COMMAND_DEL = ( + GROUP_PHP_COMMAND_PREFIX + + """ $group['name'] = '{name}'; local_group_del($group); """ +) class PFSenseGroupModule(PFSenseModuleBase): - """ module managing pfsense user groups """ + """module managing pfsense user groups""" def __init__(self, module, pfsense=None): super(PFSenseGroupModule, self).__init__(module, pfsense) self.name = "pfsense_group" - self.root_elt = self.pfsense.get_element('system') - self.groups = self.root_elt.findall('group') + self.root_elt = self.pfsense.get_element("system") + self.groups = self.root_elt.findall("group") ############################## # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = dict() self.obj = obj - obj['name'] = params['name'] - state = params['state'] + obj["name"] = params["name"] + state = params["state"] - if state == 'present': - obj['description'] = params['descr'] - for option in ['scope', 'gid', 'priv']: + if state == "present": + obj["description"] = params["descr"] + for option in ["scope", "gid", "priv"]: if option in params and params[option] is not None: obj[option] = params[option] return obj def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" def _nextgid(self): - """ return and update netgid counter """ - nextgid_elt = self.root_elt.find('nextgid') + """return and update netgid counter""" + nextgid_elt = self.root_elt.find("nextgid") nextgid = nextgid_elt.text nextgid_elt.text = str(int(nextgid) + 1) return nextgid @@ -134,40 +145,48 @@ def _nextgid(self): # XML processing # def _copy_and_add_target(self): - """ create the XML target_elt """ - if 'gid' not in self.obj: + """create the XML target_elt""" + if "gid" not in self.obj: # Search for an open gid while True: - self.obj['gid'] = self._nextgid() - if self._find_group_by_gid(self.obj['gid']) is None: + self.obj["gid"] = self._nextgid() + if self._find_group_by_gid(self.obj["gid"]) is None: break else: - if self._find_group_by_gid(self.obj['gid']) is not None: - self.module.fail_json(msg='A different group already exists with gid {0}.'.format(self.obj['gid'])) + if self._find_group_by_gid(self.obj["gid"]) is not None: + self.module.fail_json( + msg="A different group already exists with gid {0}.".format( + self.obj["gid"] + ) + ) self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) self.root_elt.insert(self._find_last_group_index(), self.target_elt) # Reset groups list - self.groups = self.root_elt.findall('group') + self.groups = self.root_elt.findall("group") def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) - self.diff['before'] = before + self.diff["before"] = before changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'].update(self.pfsense.element_to_dict(self.target_elt)) + self.diff["after"].update(self.pfsense.element_to_dict(self.target_elt)) return (before, changed) def _create_target(self): - """ create the XML target_elt """ - return self.pfsense.new_element('group') + """create the XML target_elt""" + return self.pfsense.new_element("group") def _find_target(self): - return self.pfsense.find_elt('group', self.obj['name'], search_field='name', root_elt=self.root_elt) + return self.pfsense.find_elt( + "group", self.obj["name"], search_field="name", root_elt=self.root_elt + ) def _find_group_by_gid(self, gid): - return self.pfsense.find_elt('group', gid, search_field='gid', root_elt=self.root_elt) + return self.pfsense.find_elt( + "group", gid, search_field="gid", root_elt=self.root_elt + ) def _find_this_group_index(self): return self.groups.index(self.target_elt) @@ -179,46 +198,45 @@ def _find_last_group_index(self): # run # def _update(self): - if self.params['state'] == 'present': - return self.pfsense.phpshell(GROUP_PHP_COMMAND_SET.format(idx=self._find_this_group_index())) + if self.params["state"] == "present": + return self.pfsense.phpshell( + GROUP_PHP_COMMAND_SET.format(idx=self._find_this_group_index()) + ) else: - return self.pfsense.phpshell(GROUP_PHP_COMMAND_DEL.format(name=self.obj['name'])) + return self.pfsense.phpshell( + GROUP_PHP_COMMAND_DEL.format(name=self.obj["name"]) + ) ############################## # Logging # def _get_obj_name(self): - """ return obj's name """ - return self.obj['name'] + """return obj's name""" + return self.obj["name"] def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" return values def main(): module = AnsibleModule( argument_spec={ - 'name': {'required': True, 'type': 'str'}, - 'state': { - 'required': True, - 'choices': ['present', 'absent'] - }, - 'descr': {'required': False, 'type': 'str'}, - 'scope': { - 'default': 'local', - 'choices': ['local', 'remote', 'system'] - }, - 'gid': {'required': False, 'type': 'str'}, - 'priv': {'required': False, 'type': 'list', 'elements': 'str'}, + "name": {"required": True, "type": "str"}, + "state": {"required": True, "choices": ["present", "absent"]}, + "descr": {"required": False, "type": "str"}, + "scope": {"default": "local", "choices": ["local", "remote", "system"]}, + "gid": {"required": False, "type": "str"}, + "priv": {"required": False, "type": "list", "elements": "str"}, }, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseGroupModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_haproxy_backend.py b/plugins/modules/pfsense_haproxy_backend.py index 79c55bed..cebe1e92 100644 --- a/plugins/modules/pfsense_haproxy_backend.py +++ b/plugins/modules/pfsense_haproxy_backend.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -123,18 +126,21 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.haproxy_backend import PFSenseHaproxyBackendModule, HAPROXY_BACKEND_ARGUMENT_SPEC +from ansible_collections.pfsensible.core.plugins.module_utils.haproxy_backend import ( + PFSenseHaproxyBackendModule, + HAPROXY_BACKEND_ARGUMENT_SPEC, +) def main(): module = AnsibleModule( - argument_spec=HAPROXY_BACKEND_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=HAPROXY_BACKEND_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSenseHaproxyBackendModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_haproxy_backend_server.py b/plugins/modules/pfsense_haproxy_backend_server.py index 241fb327..6555808f 100644 --- a/plugins/modules/pfsense_haproxy_backend_server.py +++ b/plugins/modules/pfsense_haproxy_backend_server.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -147,12 +150,13 @@ def main(): module = AnsibleModule( argument_spec=HAPROXY_BACKEND_SERVER_ARGUMENT_SPEC, mutually_exclusive=HAPROXY_BACKEND_SERVER_MUTUALLY_EXCLUSIVE, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseHaproxyBackendServerModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_interface.py b/plugins/modules/pfsense_interface.py index 8798c6b8..b8460ec3 100644 --- a/plugins/modules/pfsense_interface.py +++ b/plugins/modules/pfsense_interface.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -144,7 +147,7 @@ PFSenseInterfaceModule, INTERFACE_ARGUMENT_SPEC, INTERFACE_REQUIRED_IF, - INTERFACE_MUTUALLY_EXCLUSIVE + INTERFACE_MUTUALLY_EXCLUSIVE, ) @@ -153,12 +156,13 @@ def main(): argument_spec=INTERFACE_ARGUMENT_SPEC, required_if=INTERFACE_REQUIRED_IF, mutually_exclusive=INTERFACE_MUTUALLY_EXCLUSIVE, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseInterfaceModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_interface_group.py b/plugins/modules/pfsense_interface_group.py index 9bfcccb7..04f24224 100644 --- a/plugins/modules/pfsense_interface_group.py +++ b/plugins/modules/pfsense_interface_group.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -78,7 +81,7 @@ from ansible_collections.pfsensible.core.plugins.module_utils.interface_group import ( PFSenseInterfaceGroupModule, INTERFACE_GROUP_ARGUMENT_SPEC, - INTERFACE_GROUP_REQUIRED_IF + INTERFACE_GROUP_REQUIRED_IF, ) @@ -86,12 +89,13 @@ def main(): module = AnsibleModule( argument_spec=INTERFACE_GROUP_ARGUMENT_SPEC, required_if=INTERFACE_GROUP_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseInterfaceGroupModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_ipsec.py b/plugins/modules/pfsense_ipsec.py index 3c16f252..52b5abd6 100644 --- a/plugins/modules/pfsense_ipsec.py +++ b/plugins/modules/pfsense_ipsec.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -204,19 +207,24 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec import PFSenseIpsecModule, IPSEC_ARGUMENT_SPEC, IPSEC_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec import ( + PFSenseIpsecModule, + IPSEC_ARGUMENT_SPEC, + IPSEC_REQUIRED_IF, +) def main(): module = AnsibleModule( argument_spec=IPSEC_ARGUMENT_SPEC, required_if=IPSEC_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseIpsecModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_ipsec_aggregate.py b/plugins/modules/pfsense_ipsec_aggregate.py index 2166a83e..d5666c19 100644 --- a/plugins/modules/pfsense_ipsec_aggregate.py +++ b/plugins/modules/pfsense_ipsec_aggregate.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -424,21 +427,39 @@ sample: ["create ipsec_p2 'test_p2' on 'test_tunnel', disabled='False', mode='vti', local='1.2.3.1', ...", "delete ipsec_p2 'test_p2' on 'test_tunnel'"] """ -from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import PFSenseModule -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec import PFSenseIpsecModule, IPSEC_ARGUMENT_SPEC, IPSEC_REQUIRED_IF -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import PFSenseIpsecProposalModule -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import IPSEC_PROPOSAL_ARGUMENT_SPEC -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import IPSEC_PROPOSAL_REQUIRED_IF -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import PFSenseIpsecP2Module -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import IPSEC_P2_ARGUMENT_SPEC -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import IPSEC_P2_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import ( + PFSenseModule, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec import ( + PFSenseIpsecModule, + IPSEC_ARGUMENT_SPEC, + IPSEC_REQUIRED_IF, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import ( + PFSenseIpsecProposalModule, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import ( + IPSEC_PROPOSAL_ARGUMENT_SPEC, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import ( + IPSEC_PROPOSAL_REQUIRED_IF, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import ( + PFSenseIpsecP2Module, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import ( + IPSEC_P2_ARGUMENT_SPEC, +) +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import ( + IPSEC_P2_REQUIRED_IF, +) from ansible.module_utils.basic import AnsibleModule from copy import deepcopy class PFSenseModuleIpsecAggregate(object): - """ module managing pfsense aggregated IPsec tunnels, phases 1, phases 2 and proposals """ + """module managing pfsense aggregated IPsec tunnels, phases 1, phases 2 and proposals""" def __init__(self, module): self.module = module @@ -448,64 +469,68 @@ def __init__(self, module): self.pfsense_ipsec_p2 = PFSenseIpsecP2Module(module, self.pfsense) def _update(self): - if self.pfsense_ipsec.result['changed'] or self.pfsense_ipsec_proposal.result['changed'] or self.pfsense_ipsec_p2.result['changed']: + if ( + self.pfsense_ipsec.result["changed"] + or self.pfsense_ipsec_proposal.result["changed"] + or self.pfsense_ipsec_p2.result["changed"] + ): return self.pfsense.apply_ipsec_changes() - return ('', '', '') + return ("", "", "") @staticmethod def want_ipsec(ipsec_elt, ipsecs): - """ return True if we want to keep ipsec_elt """ - descr = ipsec_elt.find('descr') + """return True if we want to keep ipsec_elt""" + descr = ipsec_elt.find("descr") if descr is None: return True for ipsec in ipsecs: - if ipsec['state'] == 'absent': + if ipsec["state"] == "absent": continue - if ipsec['descr'] == descr.text: + if ipsec["descr"] == descr.text: return True return False def proposal_elt_to_params(self, ipsec_elt, proposal_elt): - """ return the pfsense_ipsec_proposal params corresponding the proposal_elt """ + """return the pfsense_ipsec_proposal params corresponding the proposal_elt""" params = {} proposal = self.pfsense.element_to_dict(proposal_elt) - params['encryption'] = proposal['encryption-algorithm']['name'] - params['key_length'] = proposal['encryption-algorithm'].get('keylen') - if params['key_length'] is not None: - if params['key_length'] == '': - params['key_length'] = None + params["encryption"] = proposal["encryption-algorithm"]["name"] + params["key_length"] = proposal["encryption-algorithm"].get("keylen") + if params["key_length"] is not None: + if params["key_length"] == "": + params["key_length"] = None else: - params['key_length'] = int(params['key_length']) - params['hash'] = proposal['hash-algorithm'] - params['dhgroup'] = int(proposal['dhgroup']) - descr_elt = ipsec_elt.find('descr') + params["key_length"] = int(params["key_length"]) + params["hash"] = proposal["hash-algorithm"] + params["dhgroup"] = int(proposal["dhgroup"]) + descr_elt = ipsec_elt.find("descr") if descr_elt is None: - params['descr'] = '' + params["descr"] = "" else: - params['descr'] = descr_elt.text + params["descr"] = descr_elt.text if self.pfsense.is_at_least_2_5_0(): - params['prf'] = proposal['prf-algorithm'] + params["prf"] = proposal["prf-algorithm"] return params def want_ipsec_proposal(self, ipsec_elt, proposal_elt, proposals): - """ return True if we want to keep proposal_elt """ + """return True if we want to keep proposal_elt""" params_from_elt = self.proposal_elt_to_params(ipsec_elt, proposal_elt) - params_from_elt['state'] = 'present' + params_from_elt["state"] = "present" if proposals is not None: for proposal in proposals: _proposal = deepcopy(proposal) - _proposal.pop('apply', None) + _proposal.pop("apply", None) if not self.pfsense.is_at_least_2_5_0(): - _proposal.pop('prf', None) - elif _proposal.get('prf') is None: - _proposal.pop('prf', None) - params_from_elt.pop('prf', None) + _proposal.pop("prf", None) + elif _proposal.get("prf") is None: + _proposal.pop("prf", None) + params_from_elt.pop("prf", None) if params_from_elt == _proposal: return True @@ -513,32 +538,32 @@ def want_ipsec_proposal(self, ipsec_elt, proposal_elt, proposals): return False def want_ipsec_phase2(self, phase2_elt, phases2): - """ return True if we want to keep proposal_elt """ - ikeid_elt = phase2_elt.find('ikeid') - descr = phase2_elt.find('descr') + """return True if we want to keep proposal_elt""" + ikeid_elt = phase2_elt.find("ikeid") + descr = phase2_elt.find("descr") if descr is None or ikeid_elt is None: return True - phase1_elt = self.pfsense.find_ipsec_phase1(ikeid_elt.text, 'ikeid') + phase1_elt = self.pfsense.find_ipsec_phase1(ikeid_elt.text, "ikeid") if phase1_elt is None: return True - phase1_descr_elt = phase1_elt.find('descr') + phase1_descr_elt = phase1_elt.find("descr") if phase1_descr_elt is None: return True p1_descr = phase1_descr_elt.text if phases2 is not None: for phase2 in phases2: - if phase2['state'] == 'absent': + if phase2["state"] == "absent": continue - if phase2['descr'] == descr.text and phase2['p1_descr'] == p1_descr: + if phase2["descr"] == descr.text and phase2["p1_descr"] == p1_descr: return True return False def run_ipsecs(self): - """ process input params to add/update/delete all IPsec tunnels """ - want = self.module.params['aggregated_ipsecs'] + """process input params to add/update/delete all IPsec tunnels""" + want = self.module.params["aggregated_ipsecs"] # processing aggregated parameter if want is not None: @@ -546,25 +571,25 @@ def run_ipsecs(self): self.pfsense_ipsec.run(param) # delete every other if required - if self.module.params['purge_ipsecs']: + if self.module.params["purge_ipsecs"]: todel = [] for ipsec_elt in self.pfsense_ipsec.root_elt: - if ipsec_elt.tag != 'phase1': + if ipsec_elt.tag != "phase1": continue if not self.want_ipsec(ipsec_elt, want): params = {} - params['state'] = 'absent' - params['apply'] = False - params['descr'] = ipsec_elt.find('descr').text - params['ikeid'] = ipsec_elt.find('ikeid').text + params["state"] = "absent" + params["apply"] = False + params["descr"] = ipsec_elt.find("descr").text + params["ikeid"] = ipsec_elt.find("ikeid").text todel.append(params) for params in todel: self.pfsense_ipsec.run(params) def run_ipsec_proposals(self): - """ process input params to add/update/delete all IPsec tunnels """ - want = self.module.params['aggregated_ipsec_proposals'] + """process input params to add/update/delete all IPsec tunnels""" + want = self.module.params["aggregated_ipsec_proposals"] # processing aggregated parameter if want is not None: @@ -572,32 +597,32 @@ def run_ipsec_proposals(self): self.pfsense_ipsec_proposal.run(param) # delete every other if required - if self.module.params['purge_ipsec_proposals']: + if self.module.params["purge_ipsec_proposals"]: todel = [] for ipsec_elt in self.pfsense_ipsec_proposal.ipsec: - if ipsec_elt.tag != 'phase1': + if ipsec_elt.tag != "phase1": continue - encryption_elt = ipsec_elt.find('encryption') + encryption_elt = ipsec_elt.find("encryption") if encryption_elt is None: continue - items_elt = encryption_elt.findall('item') + items_elt = encryption_elt.findall("item") for proposal_elt in items_elt: if not self.want_ipsec_proposal(ipsec_elt, proposal_elt, want): params = self.proposal_elt_to_params(ipsec_elt, proposal_elt) - params['state'] = 'absent' - params['apply'] = False - params['descr'] = ipsec_elt.find('descr').text - params['ikeid'] = ipsec_elt.find('ikeid').text + params["state"] = "absent" + params["apply"] = False + params["descr"] = ipsec_elt.find("descr").text + params["ikeid"] = ipsec_elt.find("ikeid").text todel.append(params) for params in todel: self.pfsense_ipsec_proposal.run(params) def run_ipsec_p2s(self): - """ process input params to add/update/delete all IPsec tunnels """ - want = self.module.params['aggregated_ipsec_p2s'] + """process input params to add/update/delete all IPsec tunnels""" + want = self.module.params["aggregated_ipsec_p2s"] # processing aggregated parameter if want is not None: @@ -605,61 +630,91 @@ def run_ipsec_p2s(self): self.pfsense_ipsec_p2.run(param) # delete every other if required - if self.module.params['purge_ipsec_p2s']: + if self.module.params["purge_ipsec_p2s"]: todel = [] for phase2_elt in self.pfsense_ipsec_p2.root_elt: - if phase2_elt.tag != 'phase2': + if phase2_elt.tag != "phase2": continue if not self.want_ipsec_phase2(phase2_elt, want): params = {} - params['state'] = 'absent' - params['apply'] = False - params['descr'] = phase2_elt.find('descr').text - params['p1_descr'] = self.pfsense.find_ipsec_phase1(phase2_elt.find('ikeid').text, 'ikeid').find('descr').text - params['ikeid'] = phase2_elt.find('ikeid').text + params["state"] = "absent" + params["apply"] = False + params["descr"] = phase2_elt.find("descr").text + params["p1_descr"] = ( + self.pfsense.find_ipsec_phase1( + phase2_elt.find("ikeid").text, "ikeid" + ) + .find("descr") + .text + ) + params["ikeid"] = phase2_elt.find("ikeid").text todel.append(params) for params in todel: self.pfsense_ipsec_p2.run(params) def commit_changes(self): - """ apply changes and exit module """ - stdout = '' - stderr = '' - changed = self.pfsense_ipsec.result['changed'] or self.pfsense_ipsec_proposal.result['changed'] or self.pfsense_ipsec_p2.result['changed'] + """apply changes and exit module""" + stdout = "" + stderr = "" + changed = ( + self.pfsense_ipsec.result["changed"] + or self.pfsense_ipsec_proposal.result["changed"] + or self.pfsense_ipsec_p2.result["changed"] + ) if changed and not self.module.check_mode: - self.pfsense.write_config(descr='aggregated change') - if self.module.params['apply']: + self.pfsense.write_config(descr="aggregated change") + if self.module.params["apply"]: (dummy, stdout, stderr) = self._update() result = {} - result['result_ipsecs'] = self.pfsense_ipsec.result['commands'] - result['result_ipsec_proposals'] = self.pfsense_ipsec_proposal.result['commands'] - result['result_ipsec_p2s'] = self.pfsense_ipsec_p2.result['commands'] - result['changed'] = changed - result['stdout'] = stdout - result['stderr'] = stderr + result["result_ipsecs"] = self.pfsense_ipsec.result["commands"] + result["result_ipsec_proposals"] = self.pfsense_ipsec_proposal.result[ + "commands" + ] + result["result_ipsec_p2s"] = self.pfsense_ipsec_p2.result["commands"] + result["changed"] = changed + result["stdout"] = stdout + result["stderr"] = stderr self.module.exit_json(**result) def main(): argument_spec = dict( - aggregated_ipsecs=dict(type='list', elements='dict', options=IPSEC_ARGUMENT_SPEC, required_if=IPSEC_REQUIRED_IF), - aggregated_ipsec_proposals=dict(type='list', elements='dict', options=IPSEC_PROPOSAL_ARGUMENT_SPEC, required_if=IPSEC_PROPOSAL_REQUIRED_IF), - aggregated_ipsec_p2s=dict(type='list', elements='dict', options=IPSEC_P2_ARGUMENT_SPEC, required_if=IPSEC_P2_REQUIRED_IF), - purge_ipsecs=dict(default=False, type='bool'), - purge_ipsec_proposals=dict(default=False, type='bool'), - purge_ipsec_p2s=dict(default=False, type='bool'), - apply=dict(default=True, type='bool'), + aggregated_ipsecs=dict( + type="list", + elements="dict", + options=IPSEC_ARGUMENT_SPEC, + required_if=IPSEC_REQUIRED_IF, + ), + aggregated_ipsec_proposals=dict( + type="list", + elements="dict", + options=IPSEC_PROPOSAL_ARGUMENT_SPEC, + required_if=IPSEC_PROPOSAL_REQUIRED_IF, + ), + aggregated_ipsec_p2s=dict( + type="list", + elements="dict", + options=IPSEC_P2_ARGUMENT_SPEC, + required_if=IPSEC_P2_REQUIRED_IF, + ), + purge_ipsecs=dict(default=False, type="bool"), + purge_ipsec_proposals=dict(default=False, type="bool"), + purge_ipsec_p2s=dict(default=False, type="bool"), + apply=dict(default=True, type="bool"), ) - required_one_of = [['aggregated_ipsecs', 'aggregated_ipsec_proposals', 'aggregated_ipsec_p2s']] + required_one_of = [ + ["aggregated_ipsecs", "aggregated_ipsec_proposals", "aggregated_ipsec_p2s"] + ] module = AnsibleModule( argument_spec=argument_spec, required_one_of=required_one_of, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseModuleIpsecAggregate(module) @@ -670,5 +725,5 @@ def main(): pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_ipsec_p2.py b/plugins/modules/pfsense_ipsec_p2.py index 1778125a..a8df7d9a 100644 --- a/plugins/modules/pfsense_ipsec_p2.py +++ b/plugins/modules/pfsense_ipsec_p2.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -185,19 +188,24 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import PFSenseIpsecP2Module, IPSEC_P2_ARGUMENT_SPEC, IPSEC_P2_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import ( + PFSenseIpsecP2Module, + IPSEC_P2_ARGUMENT_SPEC, + IPSEC_P2_REQUIRED_IF, +) def main(): module = AnsibleModule( argument_spec=IPSEC_P2_ARGUMENT_SPEC, required_if=IPSEC_P2_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseIpsecP2Module(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_ipsec_proposal.py b/plugins/modules/pfsense_ipsec_proposal.py index 889546c0..f157da24 100644 --- a/plugins/modules/pfsense_ipsec_proposal.py +++ b/plugins/modules/pfsense_ipsec_proposal.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -101,7 +104,7 @@ from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import ( PFSenseIpsecProposalModule, IPSEC_PROPOSAL_ARGUMENT_SPEC, - IPSEC_PROPOSAL_REQUIRED_IF + IPSEC_PROPOSAL_REQUIRED_IF, ) @@ -109,12 +112,13 @@ def main(): module = AnsibleModule( argument_spec=IPSEC_PROPOSAL_ARGUMENT_SPEC, required_if=IPSEC_PROPOSAL_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseIpsecProposalModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_log_settings.py b/plugins/modules/pfsense_log_settings.py index 93ef606d..12d5dd29 100644 --- a/plugins/modules/pfsense_log_settings.py +++ b/plugins/modules/pfsense_log_settings.py @@ -6,12 +6,15 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -200,67 +203,67 @@ import re from copy import deepcopy from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) LOG_SETTINGS_ARGUMENT_SPEC = dict( - logformat=dict(required=False, type='str', - choices=['rfc3164', 'rfc5424']), - reverse=dict(required=False, type='bool'), - nentries=dict(required=False, type='int'), - nologdefaultblock=dict(required=False, type='bool'), - nologdefaultpass=dict(required=False, type='bool'), - nologbogons=dict(required=False, type='bool'), - nologprivatenets=dict(required=False, type='bool'), - nologlinklocal4=dict(required=False, type='bool'), - nologsnort2c=dict(required=False, type='bool'), - nolognginx=dict(required=False, type='bool'), - logconfigchanges=dict(required=False, type='bool'), - rawfilter=dict(required=False, type='bool'), - filterdescriptions=dict(required=False, type='int', - choices=[0, 1, 2]), - disablelocallogging=dict(required=False, type='bool'), - logfilesize=dict(required=False, type='int'), - logcompressiontype=dict(required=False, type='str', - choices=['bzip2', 'gzip', 'xz', 'zstd', 'none']), - rotatecount=dict(required=False, type='int'), - enable=dict(required=False, type='bool'), - sourceip=dict(required=False, type='str'), - ipproto=dict(required=False, type='str', - choices=['ipv4', 'ipv6']), - remoteserver=dict(required=False, type='str'), - remoteserver2=dict(required=False, type='str'), - remoteserver3=dict(required=False, type='str'), - logall=dict(required=False, type='bool'), - system=dict(required=False, type='bool'), - logfilter=dict(required=False, type='bool'), - resolver=dict(required=False, type='bool'), - dhcp=dict(required=False, type='bool'), - ppp=dict(required=False, type='bool'), - auth=dict(required=False, type='bool'), - portalauth=dict(required=False, type='bool'), - vpn=dict(required=False, type='bool'), - dpinger=dict(required=False, type='bool'), - routing=dict(required=False, type='bool'), - ntpd=dict(required=False, type='bool'), - hostapd=dict(required=False, type='bool'), + logformat=dict(required=False, type="str", choices=["rfc3164", "rfc5424"]), + reverse=dict(required=False, type="bool"), + nentries=dict(required=False, type="int"), + nologdefaultblock=dict(required=False, type="bool"), + nologdefaultpass=dict(required=False, type="bool"), + nologbogons=dict(required=False, type="bool"), + nologprivatenets=dict(required=False, type="bool"), + nologlinklocal4=dict(required=False, type="bool"), + nologsnort2c=dict(required=False, type="bool"), + nolognginx=dict(required=False, type="bool"), + logconfigchanges=dict(required=False, type="bool"), + rawfilter=dict(required=False, type="bool"), + filterdescriptions=dict(required=False, type="int", choices=[0, 1, 2]), + disablelocallogging=dict(required=False, type="bool"), + logfilesize=dict(required=False, type="int"), + logcompressiontype=dict( + required=False, type="str", choices=["bzip2", "gzip", "xz", "zstd", "none"] + ), + rotatecount=dict(required=False, type="int"), + enable=dict(required=False, type="bool"), + sourceip=dict(required=False, type="str"), + ipproto=dict(required=False, type="str", choices=["ipv4", "ipv6"]), + remoteserver=dict(required=False, type="str"), + remoteserver2=dict(required=False, type="str"), + remoteserver3=dict(required=False, type="str"), + logall=dict(required=False, type="bool"), + system=dict(required=False, type="bool"), + logfilter=dict(required=False, type="bool"), + resolver=dict(required=False, type="bool"), + dhcp=dict(required=False, type="bool"), + ppp=dict(required=False, type="bool"), + auth=dict(required=False, type="bool"), + portalauth=dict(required=False, type="bool"), + vpn=dict(required=False, type="bool"), + dpinger=dict(required=False, type="bool"), + routing=dict(required=False, type="bool"), + ntpd=dict(required=False, type="bool"), + hostapd=dict(required=False, type="bool"), ) # rename the reserved words with log prefix params_map = { - 'logformat': 'format', - 'logfilter': 'filter', + "logformat": "format", + "logfilter": "filter", } # fields with inverted logic -inverted_list = ['nologdefaultpass'] +inverted_list = ["nologdefaultpass"] class PFSenseLogSettingsModule(PFSenseModuleBase): - """ module managing pfsense log settings """ + """module managing pfsense log settings""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return LOG_SETTINGS_ARGUMENT_SPEC ############################## @@ -269,7 +272,7 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseLogSettingsModule, self).__init__(module, pfsense) self.name = "log_settings" - self.root_elt = self.pfsense.get_element('syslog') + self.root_elt = self.pfsense.get_element("syslog") self.target_elt = self.root_elt self.params = dict() self.obj = dict() @@ -282,7 +285,7 @@ def __init__(self, module, pfsense=None): # params processing # def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" params = self.params obj = self.pfsense.element_to_dict(self.root_elt) @@ -293,7 +296,7 @@ def _set_param(target, param): # get possibly mapped settings name _param = params_map.get(param, param) if params.get(param) is not None: - if param == 'sourceip': + if param == "sourceip": target[param] = self._get_source_ip_interface(params[param]) else: if isinstance(params[param], str): @@ -305,14 +308,18 @@ def _set_param_bool(target, param): # get possibly mapped settings name _param = params_map.get(param, param) if params.get(param) is not None: - value = not params.get(param) if param in inverted_list else params.get(param) + value = ( + not params.get(param) + if param in inverted_list + else params.get(param) + ) if value is True and _param not in target: - target[_param] = '' + target[_param] = "" elif value is False and _param in target: del target[_param] for param in LOG_SETTINGS_ARGUMENT_SPEC: - if LOG_SETTINGS_ARGUMENT_SPEC[param]['type'] == 'bool': + if LOG_SETTINGS_ARGUMENT_SPEC[param]["type"] == "bool": _set_param_bool(obj, param) else: _set_param(obj, param) @@ -322,12 +329,12 @@ def _set_param_bool(target, param): def _is_interface_ip_or_descr(self, address): result = False - if address in ['127.0.0.1', 'Localhost']: + if address in ["127.0.0.1", "Localhost"]: return True for interface_elt in self.pfsense.interfaces: - descr = interface_elt.find('descr') - ipaddr = interface_elt.find('ipaddr') + descr = interface_elt.find("descr") + ipaddr = interface_elt.find("ipaddr") if descr is not None and descr.text == address: return True @@ -337,14 +344,14 @@ def _is_interface_ip_or_descr(self, address): return result def _get_interface_by_ip_or_display_name(self, address): - """ return interface_id by ip address or name """ + """return interface_id by ip address or name""" - if address in ['127.0.0.1', 'Localhost']: - return 'lo0' + if address in ["127.0.0.1", "Localhost"]: + return "lo0" for interface_elt in self.pfsense.interfaces: - descr = interface_elt.find('descr') - ipaddr = interface_elt.find('ipaddr') + descr = interface_elt.find("descr") + ipaddr = interface_elt.find("ipaddr") if descr is not None and descr.text == address: return interface_elt.tag @@ -365,17 +372,17 @@ def _get_source_ip_interface(self, address): return result def _validate_syslog_server(self, hostname, name): - """ check hostname / ip address combinations with optional port """ + """check hostname / ip address combinations with optional port""" if not hostname: return host = hostname.lower() - contains_port = re.match(r'^(\[.+\]|[^:]+):[0-9]+$', host) + contains_port = re.match(r"^(\[.+\]|[^:]+):[0-9]+$", host) if contains_port is not None: - host, port = host.rsplit(':', 1) + host, port = host.rsplit(":", 1) # check if we got a ipv6 address with port - need to remove '[' and ']' - host = host.strip('[]') + host = host.strip("[]") if port is not None and (int(port) <= 0 or int(port) >= 65536): self.module.fail_json(msg="Invalid port {0}".format(port)) @@ -386,51 +393,93 @@ def _validate_syslog_server(self, hostname, name): if self.pfsense.is_ipv6_address(host): return - groups = re.match(r'^(?:(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9])\.)*(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9\.])$', host) + groups = re.match( + r"^(?:(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9])\.)*(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9\.])$", + host, + ) if groups is None: - self.module.fail_json(msg="The {0} can only contain the characters A-Z, 0-9 and '-'. It may not start or end with '-'".format(name)) + self.module.fail_json( + msg="The {0} can only contain the characters A-Z, 0-9 and '-'. It may not start or end with '-'".format( + name + ) + ) def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" params = self.params - if params.get('sourceip') is not None: - address = params.get('sourceip') - if address == '': + if params.get("sourceip") is not None: + address = params.get("sourceip") + if address == "": return - if not self.pfsense.is_virtual_ip(address) and not self._is_interface_ip_or_descr(address): - self.module.fail_json(msg="sourceip: Invalid address {address}!".format(address=params.get('sourceip'))) - - if params.get('logall') is True: - for log_param in ['system', 'logfilter', 'resolver', - 'dhcp', 'ppp', 'auth', 'portalauth', - 'vpn', 'dpinger', 'routing', 'ntpd', 'hostapd']: + if not self.pfsense.is_virtual_ip( + address + ) and not self._is_interface_ip_or_descr(address): + self.module.fail_json( + msg="sourceip: Invalid address {address}!".format( + address=params.get("sourceip") + ) + ) + + if params.get("logall") is True: + for log_param in [ + "system", + "logfilter", + "resolver", + "dhcp", + "ppp", + "auth", + "portalauth", + "vpn", + "dpinger", + "routing", + "ntpd", + "hostapd", + ]: if params.get(log_param) is True: - self.module.fail_json(msg="{log_param} = True is invalid when logall is True".format(log_param=log_param)) - - if params.get('enable') is True: - remote_params = ['remoteserver', 'remoteserver2', 'remoteserver3'] - if params.get('remoteserver') is None and params.get('remoteserver2') is None and params.get('remoteserver3') is None: - self.module.fail_json(msg="Need at least one remote syslog server when remote logging is enabled") + self.module.fail_json( + msg="{log_param} = True is invalid when logall is True".format( + log_param=log_param + ) + ) + + if params.get("enable") is True: + remote_params = ["remoteserver", "remoteserver2", "remoteserver3"] + if ( + params.get("remoteserver") is None + and params.get("remoteserver2") is None + and params.get("remoteserver3") is None + ): + self.module.fail_json( + msg="Need at least one remote syslog server when remote logging is enabled" + ) else: for param in remote_params: self._validate_syslog_server(params.get(param), param) - if params.get('nentries') is not None: - nentries = int(params.get('nentries')) + if params.get("nentries") is not None: + nentries = int(params.get("nentries")) if nentries < 5 or nentries > 200000: - self.module.fail_json(msg="nentries must be an integer from 5 to 200000") + self.module.fail_json( + msg="nentries must be an integer from 5 to 200000" + ) - if params.get('logfilesize') is not None: - logfilesize = int(params.get('logfilesize')) + if params.get("logfilesize") is not None: + logfilesize = int(params.get("logfilesize")) if logfilesize < 100000: - self.module.fail_json(msg="logfilesize must be an integer greater or equal than 100000") - elif logfilesize >= (2 ** 32) / 2: - self.module.fail_json(msg="logfilesize is too large: {logfilesize}".format(logfilesize=logfilesize)) - - if params.get('rotatecount') is not None: - rotatecount = int(params.get('rotatecount')) + self.module.fail_json( + msg="logfilesize must be an integer greater or equal than 100000" + ) + elif logfilesize >= (2**32) / 2: + self.module.fail_json( + msg="logfilesize is too large: {logfilesize}".format( + logfilesize=logfilesize + ) + ) + + if params.get("rotatecount") is not None: + rotatecount = int(params.get("rotatecount")) if rotatecount < 0 or rotatecount > 99: self.module.fail_json(msg="rotatecount must be an integer from 0 to 99") @@ -438,12 +487,14 @@ def _validate_params(self): # XML processing # def _remove_deleted_params(self): - """ Remove from target_elt a few deleted params """ + """Remove from target_elt a few deleted params""" changed = False for param in LOG_SETTINGS_ARGUMENT_SPEC: - if LOG_SETTINGS_ARGUMENT_SPEC[param]['type'] == 'bool': + if LOG_SETTINGS_ARGUMENT_SPEC[param]["type"] == "bool": _param = params_map.get(param, param) - if self.pfsense.remove_deleted_param_from_elt(self.target_elt, _param, self.obj): + if self.pfsense.remove_deleted_param_from_elt( + self.target_elt, _param, self.obj + ): changed = True return changed @@ -452,7 +503,7 @@ def _remove_deleted_params(self): # run # def run(self, params): - """ process input params to add/update/delete """ + """process input params to add/update/delete""" self.params = params self.target_elt = self.root_elt self._validate_params() @@ -460,28 +511,45 @@ def run(self, params): self._add() def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" for cmd in self.route_cmds: self.module.run_command(cmd) - cmd = ''' + cmd = """ require_once("filter.inc"); $retval = 0; -$retval |= system_syslogd_start();''' - - for param in ['nologdefaultblock', 'nologdefaultpass', 'nologbogons', 'nologprivatenets', 'nologlinklocal4', 'nologsnort2c']: +$retval |= system_syslogd_start();""" + + for param in [ + "nologdefaultblock", + "nologdefaultpass", + "nologbogons", + "nologprivatenets", + "nologlinklocal4", + "nologsnort2c", + ]: if self.params.get(param) is not None: - if (self.params[param] and param not in self.before or not self.params[param] and param in self.before): - cmd += '$retval |= filter_configure();\n' + if ( + self.params[param] + and param not in self.before + or not self.params[param] + and param in self.before + ): + cmd += "$retval |= filter_configure();\n" break - if self.params.get('nolognginx') is not None: - if (self.params['nolognginx'] and 'nolognginx' not in self.before or not self.params['nolognginx'] and 'nolognginx' in self.before): - cmd += 'ob_flush();\n' - cmd += 'flush();\n' + if self.params.get("nolognginx") is not None: + if ( + self.params["nolognginx"] + and "nolognginx" not in self.before + or not self.params["nolognginx"] + and "nolognginx" in self.before + ): + cmd += "ob_flush();\n" + cmd += "flush();\n" cmd += 'send_event("service restart webgui");\n' - cmd += '$retval |= filter_pflog_start(true);\n' + cmd += "$retval |= filter_pflog_start(true);\n" return self.pfsense.phpshell(cmd) @@ -490,32 +558,41 @@ def _update(self): # @staticmethod def _get_obj_name(): - """ return obj's name """ + """return obj's name""" return "syslog" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - values = '' + """generate pseudo-CLI command fields parameters to create an obj""" + values = "" for param in LOG_SETTINGS_ARGUMENT_SPEC: _param = params_map.get(param, param) - if LOG_SETTINGS_ARGUMENT_SPEC[param]['type'] == 'bool': - values += self.format_updated_cli_field(self.obj, self.before, _param, fvalue=self.fvalue_bool, add_comma=(values), log_none=False) + if LOG_SETTINGS_ARGUMENT_SPEC[param]["type"] == "bool": + values += self.format_updated_cli_field( + self.obj, + self.before, + _param, + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) else: - values += self.format_updated_cli_field(self.obj, self.before, _param, add_comma=(values), log_none=False) + values += self.format_updated_cli_field( + self.obj, self.before, _param, add_comma=(values), log_none=False + ) return values def main(): module = AnsibleModule( - argument_spec=LOG_SETTINGS_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=LOG_SETTINGS_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSenseLogSettingsModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_nat_outbound.py b/plugins/modules/pfsense_nat_outbound.py index da5ec890..d3f19a9f 100644 --- a/plugins/modules/pfsense_nat_outbound.py +++ b/plugins/modules/pfsense_nat_outbound.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -136,12 +139,13 @@ def main(): argument_spec=NAT_OUTBOUND_ARGUMENT_SPEC, mutually_exclusive=NAT_OUTBOUND_MUTUALLY_EXCLUSIVE, required_if=NAT_OUTBOUND_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseNatOutboundModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_nat_port_forward.py b/plugins/modules/pfsense_nat_port_forward.py index 87b93fb8..6870ac4e 100644 --- a/plugins/modules/pfsense_nat_port_forward.py +++ b/plugins/modules/pfsense_nat_port_forward.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -122,7 +125,7 @@ from ansible_collections.pfsensible.core.plugins.module_utils.nat_port_forward import ( PFSenseNatPortForwardModule, NAT_PORT_FORWARD_ARGUMENT_SPEC, - NAT_PORT_FORWARD_REQUIRED_IF + NAT_PORT_FORWARD_REQUIRED_IF, ) @@ -130,12 +133,13 @@ def main(): module = AnsibleModule( argument_spec=NAT_PORT_FORWARD_ARGUMENT_SPEC, required_if=NAT_PORT_FORWARD_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseNatPortForwardModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_openvpn_client.py b/plugins/modules/pfsense_openvpn_client.py index c96bb766..f2267eb1 100644 --- a/plugins/modules/pfsense_openvpn_client.py +++ b/plugins/modules/pfsense_openvpn_client.py @@ -5,9 +5,10 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -DOCUMENTATION = r''' +DOCUMENTATION = r""" --- module: pfsense_openvpn_client short_description: Manage pfSense OpenVPN configuration @@ -220,15 +221,15 @@ required: false default: null type: str -''' +""" -EXAMPLES = r''' +EXAMPLES = r""" - name: "Add OpenVPN client" pfsense_openvpn_client: name: 'OpenVPN Client' -''' +""" -RETURN = r''' +RETURN = r""" shared_key: description: The generated shared key, base64 encoded returned: when `generate` is passed as the shared_key argument and a key is generated. @@ -251,13 +252,13 @@ 4ODNjNDU3NTdlZTVjMWQ4ZDk5ZjM4ZjcKZGNiZDAwZmI3Nzc2ZWFlYjQ1ZmQwOTBjNGNlYTNmMGMKMzgzNDE0ZTJlYmU4MWNiZGIxZmNlN2M2YmFhMDlkMWYKMTU4OGUzNGRkYzUxY2NjOTE5NDNjNTFh OTI2OTE3NWQKNzZiZjdhOWI1ZmM3NDAyNmE3MTVkNGVmODVkYzY2Y2UKMWE5MWQwNjNhODIwZDY4MTc0ODlmYjJkZjNmYzY2MmMKMmU2OWZiMzNiMzM5MjdjYjUyNThkZDQ4M2NkNDE0Y2QKMDJhZWE3Z jA3MmNhZmEwOTY5Yjg5NWVjYzNiYmExNGQKLS0tLS1FTkQgT3BlblZQTiBTdGF0aWMga2V5IFYxLS0tLS0K -''' +""" from ansible.module_utils.basic import AnsibleModule from ansible_collections.pfsensible.core.plugins.module_utils.openvpn_client import ( PFSenseOpenVPNClientModule, OPENVPN_CLIENT_ARGUMENT_SPEC, - OPENVPN_CLIENT_REQUIRED_IF + OPENVPN_CLIENT_REQUIRED_IF, ) @@ -265,12 +266,13 @@ def main(): module = AnsibleModule( argument_spec=OPENVPN_CLIENT_ARGUMENT_SPEC, required_if=OPENVPN_CLIENT_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfopenvpn = PFSenseOpenVPNClientModule(module) pfopenvpn.run(module.params) pfopenvpn.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_openvpn_override.py b/plugins/modules/pfsense_openvpn_override.py index a19e2eba..8fc138e8 100644 --- a/plugins/modules/pfsense_openvpn_override.py +++ b/plugins/modules/pfsense_openvpn_override.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -135,7 +138,7 @@ from ansible_collections.pfsensible.core.plugins.module_utils.openvpn_override import ( PFSenseOpenVPNOverrideModule, OPENVPN_OVERRIDE_ARGUMENT_SPEC, - OPENVPN_OVERRIDE_REQUIRED_IF + OPENVPN_OVERRIDE_REQUIRED_IF, ) @@ -143,12 +146,13 @@ def main(): module = AnsibleModule( argument_spec=OPENVPN_OVERRIDE_ARGUMENT_SPEC, required_if=OPENVPN_OVERRIDE_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseOpenVPNOverrideModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_openvpn_server.py b/plugins/modules/pfsense_openvpn_server.py index badcff8f..e5b7c9f3 100644 --- a/plugins/modules/pfsense_openvpn_server.py +++ b/plugins/modules/pfsense_openvpn_server.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -326,7 +329,7 @@ mode: server_user """ -RETURN = r''' +RETURN = r""" shared_key: description: The generated shared key, base64 encoded returned: when `generate` is passed as the shared_key argument and a key is generated. @@ -354,13 +357,13 @@ returned: always type: int sample: 1 -''' +""" from ansible.module_utils.basic import AnsibleModule from ansible_collections.pfsensible.core.plugins.module_utils.openvpn_server import ( PFSenseOpenVPNServerModule, OPENVPN_SERVER_ARGUMENT_SPEC, - OPENVPN_SERVER_REQUIRED_IF + OPENVPN_SERVER_REQUIRED_IF, ) @@ -368,12 +371,13 @@ def main(): module = AnsibleModule( argument_spec=OPENVPN_SERVER_ARGUMENT_SPEC, required_if=OPENVPN_SERVER_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfopenvpn = PFSenseOpenVPNServerModule(module) pfopenvpn.run(module.params) pfopenvpn.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_phpshell.py b/plugins/modules/pfsense_phpshell.py index 265d40c4..7b82eb57 100644 --- a/plugins/modules/pfsense_phpshell.py +++ b/plugins/modules/pfsense_phpshell.py @@ -5,12 +5,15 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -72,20 +75,20 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) -PHP_SHELL_ARGUMENT_SPEC = dict( - cmd=dict(required=True, type='str') -) +PHP_SHELL_ARGUMENT_SPEC = dict(cmd=dict(required=True, type="str")) class PFSensePHPShellModule(PFSenseModuleBase): - """ module run php code on pfsense """ + """module run php code on pfsense""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return PHP_SHELL_ARGUMENT_SPEC ############################## @@ -94,33 +97,37 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSensePHPShellModule, self).__init__(module, pfsense) self.name = "pfsense_phpshell" - self.result['changed'] = True + self.result["changed"] = True ############################## # run # def run(self, params): - (rc, stdout, stderr) = self.pfsense.phpshell(params['cmd']) - self.result.update({ - 'rc': rc, - 'stdout': stdout, - 'stderr': stderr, - }) + (rc, stdout, stderr) = self.pfsense.phpshell(params["cmd"]) + self.result.update( + { + "rc": rc, + "stdout": stdout, + "stderr": stderr, + } + ) if int(rc) != 0 or len(stderr) > 0: - self.module.fail_json(msg='rc is not 0 or stderr contains output (you still could overwrite with failed_when)') + self.module.fail_json( + msg="rc is not 0 or stderr contains output (you still could overwrite with failed_when)" + ) else: self.module.exit_json(**self.result) def main(): module = AnsibleModule( - argument_spec=PHP_SHELL_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=PHP_SHELL_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSensePHPShellModule(module) pfmodule.run(module.params) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_rewrite_config.py b/plugins/modules/pfsense_rewrite_config.py index a8fccc46..d0ecd1b5 100644 --- a/plugins/modules/pfsense_rewrite_config.py +++ b/plugins/modules/pfsense_rewrite_config.py @@ -5,12 +5,15 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -32,18 +35,20 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) REWRITE_CONFIG_ARGUMENT_SPEC = dict() class PFSenseRewriteConfigModule(PFSenseModuleBase): - """ module managing pfsense routes """ + """module managing pfsense routes""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return REWRITE_CONFIG_ARGUMENT_SPEC ############################## @@ -52,38 +57,38 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseRewriteConfigModule, self).__init__(module, pfsense) self.name = "pfsense_rewrite_config" - self.result['changed'] = True + self.result["changed"] = True ############################## # run # def commit_changes(self): - """ apply changes and exit module """ - self.result['stdout'] = '' - self.result['stderr'] = '' - if self.result['changed'] and not self.module.check_mode: - (dummy, self.result['stdout'], self.result['stderr']) = self._update() + """apply changes and exit module""" + self.result["stdout"] = "" + self.result["stderr"] = "" + if self.result["changed"] and not self.module.check_mode: + (dummy, self.result["stdout"], self.result["stderr"]) = self._update() self.module.exit_json(**self.result) def _update(self): - """ make the target pfsense rewrite the config.xml file """ + """make the target pfsense rewrite the config.xml file""" - cmd = ''' + cmd = """ parse_config(true); -write_config('pfsense_rewrite_config');''' +write_config('pfsense_rewrite_config');""" return self.pfsense.phpshell(cmd) def main(): module = AnsibleModule( - argument_spec=REWRITE_CONFIG_ARGUMENT_SPEC, - supports_check_mode=True) + argument_spec=REWRITE_CONFIG_ARGUMENT_SPEC, supports_check_mode=True + ) pfmodule = PFSenseRewriteConfigModule(module) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_route.py b/plugins/modules/pfsense_route.py index c6ad6bec..9da0bce0 100644 --- a/plugins/modules/pfsense_route.py +++ b/plugins/modules/pfsense_route.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -68,19 +71,24 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.route import PFSenseRouteModule, ROUTE_ARGUMENT_SPEC, ROUTE_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.route import ( + PFSenseRouteModule, + ROUTE_ARGUMENT_SPEC, + ROUTE_REQUIRED_IF, +) def main(): module = AnsibleModule( argument_spec=ROUTE_ARGUMENT_SPEC, required_if=ROUTE_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseRouteModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_rule.py b/plugins/modules/pfsense_rule.py index 0c22db1e..28ca6401 100644 --- a/plugins/modules/pfsense_rule.py +++ b/plugins/modules/pfsense_rule.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -176,19 +179,24 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.rule import PFSenseRuleModule, RULE_ARGUMENT_SPEC, RULE_REQUIRED_IF +from ansible_collections.pfsensible.core.plugins.module_utils.rule import ( + PFSenseRuleModule, + RULE_ARGUMENT_SPEC, + RULE_REQUIRED_IF, +) def main(): module = AnsibleModule( argument_spec=RULE_ARGUMENT_SPEC, required_if=RULE_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseRuleModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_rule_separator.py b/plugins/modules/pfsense_rule_separator.py index fab7850e..f6b5264b 100644 --- a/plugins/modules/pfsense_rule_separator.py +++ b/plugins/modules/pfsense_rule_separator.py @@ -5,11 +5,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -72,10 +75,18 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import PFSenseRuleSeparatorModule -from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import RULE_SEPARATOR_ARGUMENT_SPEC -from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import RULE_SEPARATOR_REQUIRED_ONE_OF -from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import RULE_SEPARATOR_MUTUALLY_EXCLUSIVE +from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import ( + PFSenseRuleSeparatorModule, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import ( + RULE_SEPARATOR_ARGUMENT_SPEC, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import ( + RULE_SEPARATOR_REQUIRED_ONE_OF, +) +from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import ( + RULE_SEPARATOR_MUTUALLY_EXCLUSIVE, +) def main(): @@ -83,12 +94,13 @@ def main(): argument_spec=RULE_SEPARATOR_ARGUMENT_SPEC, required_one_of=RULE_SEPARATOR_REQUIRED_ONE_OF, mutually_exclusive=RULE_SEPARATOR_MUTUALLY_EXCLUSIVE, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseRuleSeparatorModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_setup.py b/plugins/modules/pfsense_setup.py index e3f6ef1d..bf4d3f71 100644 --- a/plugins/modules/pfsense_setup.py +++ b/plugins/modules/pfsense_setup.py @@ -5,12 +5,15 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -185,79 +188,113 @@ from ansible.module_utils.basic import AnsibleModule from ansible_collections.pfsensible.core.plugins.module_utils.arg_route import p2o_cert from ansible_collections.pfsensible.core.plugins.module_utils.arg_validate import validate_cert -from ansible_collections.pfsensible.core.plugins.module_utils.module_config_base import PFSenseModuleConfigBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_config_base import ( + PFSenseModuleConfigBase, +) SETUP_ARGUMENT_SPEC = dict( - hostname=dict(required=False, type='str'), - domain=dict(required=False, type='str'), - dns_addresses=dict(required=False, type='str'), - dns_hostnames=dict(required=False, type='str'), - dns_gateways=dict(required=False, type='str'), - dnsallowoverride=dict(required=False, type='bool'), - dnslocalhost=dict(required=False, type='str', choices=[ - '', - 'local', - 'remote', - 'true', - 'false', - ]), - timezone=dict(required=False, type='str'), - timeservers=dict(required=False, type='str'), + hostname=dict(required=False, type="str"), + domain=dict(required=False, type="str"), + dns_addresses=dict(required=False, type="str"), + dns_hostnames=dict(required=False, type="str"), + dns_gateways=dict(required=False, type="str"), + dnsallowoverride=dict(required=False, type="bool"), + dnslocalhost=dict( + required=False, + type="str", + choices=[ + "", + "local", + "remote", + "true", + "false", + ], + ), + timezone=dict(required=False, type="str"), + timeservers=dict(required=False, type="str"), language=dict( required=False, - type='str', - choices=['bs', 'de_DE', 'en_US', 'es_AR', 'es_ES', 'fr_FR', 'it_IT', 'ko_FR', 'nb_NO', 'nl_NL', 'pl_PL', 'pt_BR', 'pt_PT', 'ru_RU', 'zh_CN', - 'zh_Hans_CN', 'zh_Hans_HK', 'zh_Hant_TW'], + type="str", + choices=[ + "bs", + "de_DE", + "en_US", + "es_AR", + "es_ES", + "fr_FR", + "it_IT", + "ko_FR", + "nb_NO", + "nl_NL", + "pl_PL", + "pt_BR", + "pt_PT", + "ru_RU", + "zh_CN", + "zh_Hans_CN", + "zh_Hans_HK", + "zh_Hant_TW", + ], ), - session_timeout=dict(required=False, type='int'), - authmode=dict(required=False, type='str'), - shellauth=dict(required=False, type='bool'), - webguicert=dict(required=False, type='str'), - webguicss=dict(required=False, type='str'), - webguifixedmenu=dict(required=False, type='bool'), - webguihostnamemenu=dict(required=False, type='str', choices=['nohost', 'hostonly', 'fqdn']), - dashboardcolumns=dict(required=False, type='int'), - interfacessort=dict(required=False, type='bool'), - dashboardavailablewidgetspanel=dict(required=False, type='bool'), - systemlogsfilterpanel=dict(required=False, type='bool'), - systemlogsmanagelogpanel=dict(required=False, type='bool'), - statusmonitoringsettingspanel=dict(required=False, type='bool'), - requirestatefilter=dict(required=False, type='bool'), - webguileftcolumnhyper=dict(required=False, type='bool'), - disablealiaspopupdetail=dict(required=False, type='bool'), - roworderdragging=dict(required=False, type='bool'), - logincss=dict(required=False, type='str'), - loginshowhost=dict(required=False, type='bool'), + session_timeout=dict(required=False, type="int"), + authmode=dict(required=False, type="str"), + shellauth=dict(required=False, type="bool"), + webguicert=dict(required=False, type="str"), + webguicss=dict(required=False, type="str"), + webguifixedmenu=dict(required=False, type="bool"), + webguihostnamemenu=dict(required=False, type="str", choices=["nohost", "hostonly", "fqdn"]), + dashboardcolumns=dict(required=False, type="int"), + interfacessort=dict(required=False, type="bool"), + dashboardavailablewidgetspanel=dict(required=False, type="bool"), + systemlogsfilterpanel=dict(required=False, type="bool"), + systemlogsmanagelogpanel=dict(required=False, type="bool"), + statusmonitoringsettingspanel=dict(required=False, type="bool"), + requirestatefilter=dict(required=False, type="bool"), + webguileftcolumnhyper=dict(required=False, type="bool"), + disablealiaspopupdetail=dict(required=False, type="bool"), + roworderdragging=dict(required=False, type="bool"), + logincss=dict(required=False, type="str"), + loginshowhost=dict(required=False, type="bool"), ) def p2o_dnslocalhost(self, name, params, obj): if params[name] is not None: - if str(params.get(name)).lower() in ['', 'false']: - obj[name] = '' - elif str(params.get(name)).lower() in ['remote', 'true']: - obj[name] = 'remote' - elif params.get(name).lower() == 'local': - obj[name] = 'local' + if str(params.get(name)).lower() in ["", "false"]: + obj[name] = "" + elif str(params.get(name)).lower() in ["remote", "true"]: + obj[name] = "remote" + elif params.get(name).lower() == "local": + obj[name] = "local" def p2o_webguicss(self, name, params, obj): if params[name] is not None: # Add .css suffix if not present - if params[name][-4:] != '.css': - obj[name] = params[name] + '.css' + if params[name][-4:] != ".css": + obj[name] = params[name] + ".css" else: obj[name] = params[name] def validate_webguicss(self, webguicss): - """ check css style """ - path = '/usr/local/www/css/' - themes = [f for f in listdir(path) if isfile(join(path, f)) and f.endswith('.css') and f.find('login') == -1 and f.find('logo') == -1] - themes = map(lambda x: x.replace('.css', ''), themes) - if webguicss.rstrip('.css') not in themes: - raise ValueError("The submitted theme '%s' could not be found. Pick a different theme." % webguicss) + """check css style""" + path = "/usr/local/www/css/" + themes = [ + f + for f in listdir(path) + if isfile(join(path, f)) + and f.endswith(".css") + and f.find("login") == -1 + and f.find("logo") == -1 + ] + themes = map(lambda x: x.replace(".css", ""), themes) + if webguicss.rstrip(".css") not in themes: + raise ValueError( + "The submitted theme '%s' could not be found. Pick a different theme." + % webguicss + ) SETUP_ARG_ROUTE = dict( @@ -268,46 +305,54 @@ def validate_webguicss(self, webguicss): # Booleans that map to different values SETUP_BOOL_VALUES = dict( - webguifixedmenu=(None, 'fixed'), + webguifixedmenu=(None, "fixed"), ) SETUP_MAP_PARAM = [ - ('authmode', 'webgui/authmode'), - ('dashboardavailablewidgetspanel', 'webgui/dashboardavailablewidgetspanel'), - ('dashboardcolumns', 'webgui/dashboardcolumns'), - ('disablealiaspopupdetail', 'webgui/disablealiaspopupdetail'), - ('interfacessort', 'webgui/interfacessort'), - ('logincss', 'webgui/logincss'), - ('loginshowhost', 'webgui/loginshowhost'), - ('requirestatefilter', 'webgui/requirestatefilter'), - ('roworderdragging', 'webgui/roworderdragging'), - ('session_timeout', 'webgui/session_timeout'), - ('shellauth', 'webgui/shellauth'), - ('statusmonitoringsettingspanel', 'webgui/statusmonitoringsettingspanel'), - ('systemlogsfilterpanel', 'webgui/systemlogsfilterpanel'), - ('systemlogsmanagelogpanel', 'webgui/systemlogsmanagelogpanel'), - ('webguicert', 'webgui/ssl-certref'), - ('webguicss', 'webgui/webguicss'), - ('webguifixedmenu', 'webgui/webguifixedmenu'), - ('webguihostnamemenu', 'webgui/webguihostnamemenu'), - ('webguileftcolumnhyper', 'webgui/webguileftcolumnhyper'), + ("authmode", "webgui/authmode"), + ("dashboardavailablewidgetspanel", "webgui/dashboardavailablewidgetspanel"), + ("dashboardcolumns", "webgui/dashboardcolumns"), + ("disablealiaspopupdetail", "webgui/disablealiaspopupdetail"), + ("interfacessort", "webgui/interfacessort"), + ("logincss", "webgui/logincss"), + ("loginshowhost", "webgui/loginshowhost"), + ("requirestatefilter", "webgui/requirestatefilter"), + ("roworderdragging", "webgui/roworderdragging"), + ("session_timeout", "webgui/session_timeout"), + ("shellauth", "webgui/shellauth"), + ("statusmonitoringsettingspanel", "webgui/statusmonitoringsettingspanel"), + ("systemlogsfilterpanel", "webgui/systemlogsfilterpanel"), + ("systemlogsmanagelogpanel", "webgui/systemlogsmanagelogpanel"), + ("webguicert", "webgui/ssl-certref"), + ("webguicss", "webgui/webguicss"), + ("webguifixedmenu", "webgui/webguifixedmenu"), + ("webguihostnamemenu", "webgui/webguihostnamemenu"), + ("webguileftcolumnhyper", "webgui/webguileftcolumnhyper"), ] class PFSenseSetupModule(PFSenseModuleConfigBase): - """ module managing pfsense routes """ + """module managing pfsense routes""" @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return SETUP_ARGUMENT_SPEC ############################## # init # def __init__(self, module, pfsense=None): - super(PFSenseSetupModule, self).__init__(module, pfsense, name='pfsense_setup', root='system', arg_route=SETUP_ARG_ROUTE, bool_style='absent/present', - bool_values=SETUP_BOOL_VALUES, map_param=SETUP_MAP_PARAM) + super(PFSenseSetupModule, self).__init__( + module, + pfsense, + name="pfsense_setup", + root="system", + arg_route=SETUP_ARG_ROUTE, + bool_style="absent/present", + bool_values=SETUP_BOOL_VALUES, + map_param=SETUP_MAP_PARAM, + ) self.route_cmds = list() self.params_to_delete = list() @@ -315,52 +360,54 @@ def __init__(self, module, pfsense=None): # params processing # def _dns_params_to_obj(self, params, obj): - """ set the dns servers from params to obj """ + """set the dns servers from params to obj""" dns_addresses = None dns_hostnames = [] dns_gateways = [] idx = 0 - if params.get('dns_addresses') is not None: - dns_addresses = params['dns_addresses'].split() - del obj['dns_addresses'] - if params.get('dns_hostnames') is not None: - dns_hostnames = params['dns_hostnames'].split() - del obj['dns_hostnames'] - if params.get('dns_gateways') is not None: - dns_gateways = params['dns_gateways'].split() - del obj['dns_gateways'] + if params.get("dns_addresses") is not None: + dns_addresses = params["dns_addresses"].split() + del obj["dns_addresses"] + if params.get("dns_hostnames") is not None: + dns_hostnames = params["dns_hostnames"].split() + del obj["dns_hostnames"] + if params.get("dns_gateways") is not None: + dns_gateways = params["dns_gateways"].split() + del obj["dns_gateways"] if dns_addresses is not None: # set the servers - obj['dnsserver'] = dns_addresses + obj["dnsserver"] = dns_addresses # set the names & gateways for address in dns_addresses: - gateway = 'none' - if idx < len(dns_hostnames) and dns_hostnames[idx] != 'none': - obj['dns{0}host'.format(idx + 1)] = dns_hostnames[idx] - if idx < len(dns_gateways) and dns_gateways[idx] != 'none': + gateway = "none" + if idx < len(dns_hostnames) and dns_hostnames[idx] != "none": + obj["dns{0}host".format(idx + 1)] = dns_hostnames[idx] + if idx < len(dns_gateways) and dns_gateways[idx] != "none": gateway = dns_gateways[idx] - gw_key = 'dns{0}gw'.format(idx + 1) + gw_key = "dns{0}gw".format(idx + 1) if gw_key not in obj or gateway != obj[gw_key]: obj[gw_key] = gateway if self.pfsense.is_ipv4_address(address): - self.route_cmds.append('/sbin/route delete {0}'.format(address)) + self.route_cmds.append("/sbin/route delete {0}".format(address)) elif self.pfsense.is_ipv6_address(address): - self.route_cmds.append('/sbin/route delete -inet6 {0}'.format(address)) + self.route_cmds.append( + "/sbin/route delete -inet6 {0}".format(address) + ) idx += 1 - elif 'dnsserver' in obj: + elif "dnsserver" in obj: # no servers - del obj['dnsserver'] + del obj["dnsserver"] idx += 1 # delete everything required while True: - host = 'dns{0}host'.format(idx) - gateway = 'dns{0}gw'.format(idx) + host = "dns{0}host".format(idx) + gateway = "dns{0}gw".format(idx) if host not in obj and gateway not in obj: break if host in obj: @@ -372,122 +419,177 @@ def _dns_params_to_obj(self, params, obj): idx += 1 def _params_to_obj(self): - """ return a dict from module params """ + """return a dict from module params""" obj = super(PFSenseSetupModule, self)._params_to_obj() self._dns_params_to_obj(self.params, obj) return obj def _validate_hostname(self, hostname, name, strict=False): - """ check hostname, if strict is true, check if domain is omitted """ + """check hostname, if strict is true, check if domain is omitted""" host = hostname.lower() - groups = re.match(r'^(?:(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9])\.)*(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9\.])$', host) + groups = re.match( + r"^(?:(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9])\.)*(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9\.])$", + host, + ) if groups is None: - self.module.fail_json(msg="The {0} can only contain the characters A-Z, 0-9 and '-'. It may not start or end with '-'".format(name)) + self.module.fail_json( + msg="The {0} can only contain the characters A-Z, 0-9 and '-'. It may not start or end with '-'".format( + name + ) + ) if strict: - groups = re.match(r'^(?:[a-z0-9_]|[a-z0-9_][a-z0-9_\-]*[a-z0-9_])$', host) + groups = re.match(r"^(?:[a-z0-9_]|[a-z0-9_][a-z0-9_\-]*[a-z0-9_])$", host) if groups is None: - self.module.fail_json(msg='A valid {0} is specified, but the domain name part should be omitted'.format(name)) + self.module.fail_json( + msg="A valid {0} is specified, but the domain name part should be omitted".format( + name + ) + ) def _validate_params(self): - """ do some extra checks on input parameters """ + """do some extra checks on input parameters""" super(PFSenseSetupModule, self)._validate_params() params = self.params - if params.get('dashboardcolumns') is not None and (params['dashboardcolumns'] < 1 or params['dashboardcolumns'] > 6): - self.module.fail_json(msg='The submitted Dashboard Columns value is invalid.') - - if params.get('domain') is not None: - domain = params['domain'].lower() - groups = re.match(r'^(?:(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9])\.)*(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9\.])$', domain) + if params.get("dashboardcolumns") is not None and ( + params["dashboardcolumns"] < 1 or params["dashboardcolumns"] > 6 + ): + self.module.fail_json( + msg="The submitted Dashboard Columns value is invalid." + ) + + if params.get("domain") is not None: + domain = params["domain"].lower() + groups = re.match( + r"^(?:(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9])\.)*(?:[a-z_0-9]|[a-z_0-9][a-z_0-9\-]*[a-z_0-9\.])$", + domain, + ) if groups is None: - self.module.fail_json(msg="The domain may only contain the characters a-z, 0-9, '-' and '.'") + self.module.fail_json( + msg="The domain may only contain the characters a-z, 0-9, '-' and '.'" + ) - if params.get('hostname') is not None: - self._validate_hostname(params['hostname'], 'hostname', True) + if params.get("hostname") is not None: + self._validate_hostname(params["hostname"], "hostname", True) - if params.get('logincss') is not None: + if params.get("logincss") is not None: error = False try: - int(params['logincss'], 16) + int(params["logincss"], 16) except ValueError: error = True - if error or len(params['logincss']) != 6: - self.module.fail_json(msg="logincss must be a six digits hexadecimal string.") - - if params.get('timezone') is not None: - self._validate_timezone(params['timezone']) - - if params.get('timeservers') is not None: - for timeserver in params['timeservers'].split(' '): - self._validate_hostname(timeserver, 'timeserver') - - if params.get('authmode') is not None: - value = params.get('authmode') - if value != 'Local Database': - authserver_elt = self.pfsense.find_elt('authserver', value, search_field='name', root_elt=self.root_elt) + if error or len(params["logincss"]) != 6: + self.module.fail_json( + msg="logincss must be a six digits hexadecimal string." + ) + + if params.get("timezone") is not None: + self._validate_timezone(params["timezone"]) + + if params.get("timeservers") is not None: + for timeserver in params["timeservers"].split(" "): + self._validate_hostname(timeserver, "timeserver") + + if params.get("authmode") is not None: + value = params.get("authmode") + if value != "Local Database": + authserver_elt = self.pfsense.find_elt( + "authserver", value, search_field="name", root_elt=self.root_elt + ) if authserver_elt is None: - self.module.fail_json(msg="Given authserver '{0}' could not be found.".format(value)) + self.module.fail_json( + msg="Given authserver '{0}' could not be found.".format(value) + ) - if params.get('shellauth') is not None and params.get('shellauth') is True: - if authserver_elt.find('type').text == 'ldap': + if ( + params.get("shellauth") is not None + and params.get("shellauth") is True + ): + if authserver_elt.find("type").text == "ldap": # check if ldap_pam_groupdn is set - if authserver_elt.find('ldap_pam_groupdn') is None or \ - authserver_elt.find('ldap_pam_groupdn').text is None or \ - authserver_elt.find('ldap_pam_groupdn').text == '': - self.module.fail_json(msg="ldap_pam_groupdn not set for authserver '{0}'.".format(value)) + if ( + authserver_elt.find("ldap_pam_groupdn") is None + or authserver_elt.find("ldap_pam_groupdn").text is None + or authserver_elt.find("ldap_pam_groupdn").text == "" + ): + self.module.fail_json( + msg="ldap_pam_groupdn not set for authserver '{0}'.".format( + value + ) + ) # DNS ip_types = [] dns_addresses = [] - if params.get('dns_addresses') is not None: - dns_addresses = params['dns_addresses'].split() + if params.get("dns_addresses") is not None: + dns_addresses = params["dns_addresses"].split() for address in dns_addresses: if dns_addresses.count(address) > 1: - self.module.fail_json(msg='Each configured DNS server must have a unique IP address. Remove the duplicated IP.') + self.module.fail_json( + msg="Each configured DNS server must have a unique IP address. Remove the duplicated IP." + ) if self.pfsense.is_ipv4_address(address): ip_types.append(4) elif self.pfsense.is_ipv6_address(address): ip_types.append(6) else: - self.module.fail_json(msg='A valid IP address must be specified for DNS server {0}.'.format(address)) + self.module.fail_json( + msg="A valid IP address must be specified for DNS server {0}.".format( + address + ) + ) - if params.get('dns_hostnames') is not None: - for hostname in params['dns_hostnames'].split(' '): - if hostname != 'none': - self._validate_hostname(hostname, 'DNS hostname') + if params.get("dns_hostnames") is not None: + for hostname in params["dns_hostnames"].split(" "): + if hostname != "none": + self._validate_hostname(hostname, "DNS hostname") - if params.get('dns_gateways') is not None: - for idx, address in enumerate(params['dns_gateways'].split(' ')): - if idx >= len(dns_addresses) or address == 'none': + if params.get("dns_gateways") is not None: + for idx, address in enumerate(params["dns_gateways"].split(" ")): + if idx >= len(dns_addresses) or address == "none": continue - if self.pfsense.find_gateway_elt(address, protocol='inet') is not None: + if self.pfsense.find_gateway_elt(address, protocol="inet") is not None: if ip_types[idx] == 6: - self.module.fail_json(msg='The IPv4 gateway "{0}" can not be specified for IPv6 DNS server "{1}".'.format(address, dns_addresses[idx])) - elif self.pfsense.find_gateway_elt(address, protocol='inet6') is not None: + self.module.fail_json( + msg='The IPv4 gateway "{0}" can not be specified for IPv6 DNS server "{1}".'.format( + address, dns_addresses[idx] + ) + ) + elif ( + self.pfsense.find_gateway_elt(address, protocol="inet6") is not None + ): if ip_types[idx] == 4: - self.module.fail_json(msg='The IPv6 gateway "{0}" can not be specified for IPv4 DNS server "{1}".'.format(address, dns_addresses[idx])) + self.module.fail_json( + msg='The IPv6 gateway "{0}" can not be specified for IPv4 DNS server "{1}".'.format( + address, dns_addresses[idx] + ) + ) else: - self.module.fail_json(msg='The gateway "{0}" does not exist.'.format(address)) + self.module.fail_json( + msg='The gateway "{0}" does not exist.'.format(address) + ) if self.pfsense.is_within_local_networks(dns_addresses[idx]): self.module.fail_json( - msg="A gateway can not be assigned to DNS '{0}' server which is on a directly connected network.".format(dns_addresses[idx]) + msg="A gateway can not be assigned to DNS '{0}' server which is on a directly connected network.".format( + dns_addresses[idx] + ) ) def _validate_timezone(self, timezone): - """ check timezone """ - path = '/usr/share/zoneinfo/' - if not isfile(path + timezone) or timezone[:1] < 'A' or timezone[:1] > 'Z': - self.module.fail_json(msg='The submitted timezone is invalid') + """check timezone""" + path = "/usr/share/zoneinfo/" + if not isfile(path + timezone) or timezone[:1] < "A" or timezone[:1] > "Z": + self.module.fail_json(msg="The submitted timezone is invalid") ############################## # XML processing # def _get_params_to_remove(self): - """ returns the list of params to remove if they are not set """ + """returns the list of params to remove if they are not set""" to_remove = super(PFSenseSetupModule, self)._get_params_to_remove() to_remove.extend(self.params_to_delete) return to_remove @@ -496,11 +598,11 @@ def _get_params_to_remove(self): # run # def _update(self): - """ make the target pfsense reload """ + """make the target pfsense reload""" for cmd in self.route_cmds: self.module.run_command(cmd) - cmd = ''' + cmd = """ require_once("auth.inc"); require_once("filter.inc"); require_once("system_advanced_admin.inc"); @@ -514,17 +616,21 @@ def _update(self): $retval |= services_unbound_configure(); } $retval |= system_timezone_configure(); -$retval |= system_ntp_configure();''' - - if self.params.get('dnsallowoverride') is not None: - if (self.params['dnsallowoverride'] and 'dnsallowoverride' not in self.diff['before'] or - not self.params['dnsallowoverride'] and 'dnsallowoverride' in self.diff['before']): +$retval |= system_ntp_configure();""" + + if self.params.get("dnsallowoverride") is not None: + if ( + self.params["dnsallowoverride"] + and "dnsallowoverride" not in self.diff["before"] + or not self.params["dnsallowoverride"] + and "dnsallowoverride" in self.diff["before"] + ): cmd += '$retval |= send_event("service reload dns");\n' - if self.params.get('shellauth') is not None: - cmd += '$retval |= set_pam_auth();' + if self.params.get("shellauth") is not None: + cmd += "$retval |= set_pam_auth();" - cmd += '$retval |= filter_configure();\n' + cmd += "$retval |= filter_configure();\n" restart_webgui = False for param in ['ssl-certref']: @@ -541,107 +647,239 @@ def _update(self): # @staticmethod def _get_obj_name(): - """ return obj's name """ + """return obj's name""" return "general" def _log_fields(self, before=None): - """ generate pseudo-CLI command fields parameters to create an obj """ - bwebgui = self.diff['before']['webgui'] - webgui = self.obj['webgui'] + """generate pseudo-CLI command fields parameters to create an obj""" + bwebgui = self.diff["before"]["webgui"] + webgui = self.obj["webgui"] - obj_before = self._prepare_dns_log(self.diff['before']) + obj_before = self._prepare_dns_log(self.diff["before"]) obj_after = self._prepare_dns_log(self.obj) - values = '' - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'hostname', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'domain', add_comma=(values), log_none=False) - - values += self.format_updated_cli_field(obj_after, obj_before, 'dns_addresses', add_comma=(values), log_none=True) - values += self.format_updated_cli_field(obj_after, obj_before, 'dns_hostnames', add_comma=(values), log_none=True) - values += self.format_updated_cli_field(obj_after, obj_before, 'dns_gateways', add_comma=(values), log_none=True) - - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'timezone', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'timeservers', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'language', add_comma=(values), log_none=False) - - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'dnsallowoverride', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(self.obj, self.diff['before'], 'dnslocalhost', add_comma=(values), log_none=False) - - values += self.format_updated_cli_field(obj_after, obj_before, 'webguicert', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(obj_after, obj_before, 'webguicss', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'webguifixedmenu', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'webguihostnamemenu', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'dashboardcolumns', add_comma=(values), log_none=False) - - values += self.format_updated_cli_field(webgui, bwebgui, 'interfacessort', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'dashboardavailablewidgetspanel', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'systemlogsfilterpanel', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'systemlogsmanagelogpanel', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'statusmonitoringsettingspanel', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'requirestatefilter', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'webguileftcolumnhyper', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'disablealiaspopupdetail', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'roworderdragging', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'logincss', add_comma=(values), log_none=False) - values += self.format_updated_cli_field(webgui, bwebgui, 'loginshowhost', fvalue=self.fvalue_bool, add_comma=(values), log_none=False) + values = "" + values += self.format_updated_cli_field( + self.obj, + self.diff["before"], + "hostname", + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, self.diff["before"], "domain", add_comma=(values), log_none=False + ) + + values += self.format_updated_cli_field( + obj_after, obj_before, "dns_addresses", add_comma=(values), log_none=True + ) + values += self.format_updated_cli_field( + obj_after, obj_before, "dns_hostnames", add_comma=(values), log_none=True + ) + values += self.format_updated_cli_field( + obj_after, obj_before, "dns_gateways", add_comma=(values), log_none=True + ) + + values += self.format_updated_cli_field( + self.obj, + self.diff["before"], + "timezone", + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + self.diff["before"], + "timeservers", + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + self.diff["before"], + "language", + add_comma=(values), + log_none=False, + ) + + values += self.format_updated_cli_field( + self.obj, + self.diff["before"], + "dnsallowoverride", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + self.obj, + self.diff["before"], + "dnslocalhost", + add_comma=(values), + log_none=False, + ) + + values += self.format_updated_cli_field( + obj_after, obj_before, "webguicert", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + obj_after, obj_before, "webguicss", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "webguifixedmenu", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, bwebgui, "webguihostnamemenu", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + webgui, bwebgui, "dashboardcolumns", add_comma=(values), log_none=False + ) + + values += self.format_updated_cli_field( + webgui, + bwebgui, + "interfacessort", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "dashboardavailablewidgetspanel", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "systemlogsfilterpanel", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "systemlogsmanagelogpanel", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "statusmonitoringsettingspanel", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "requirestatefilter", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "webguileftcolumnhyper", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "disablealiaspopupdetail", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "roworderdragging", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) + values += self.format_updated_cli_field( + webgui, bwebgui, "logincss", add_comma=(values), log_none=False + ) + values += self.format_updated_cli_field( + webgui, + bwebgui, + "loginshowhost", + fvalue=self.fvalue_bool, + add_comma=(values), + log_none=False, + ) return values @staticmethod def _prepare_dns_log(obj): - """ construct dict for logging """ + """construct dict for logging""" ret = dict() - webgui = obj['webgui'] + webgui = obj["webgui"] - ret['webguicss'] = webgui['webguicss'].replace('.css', '') if 'webguicss' in webgui else None + ret["webguicss"] = ( + webgui["webguicss"].replace(".css", "") if "webguicss" in webgui else None + ) - if 'dnsserver' in obj: - ret['dns_addresses'] = ' '.join(obj['dnsserver']) + if "dnsserver" in obj: + ret["dns_addresses"] = " ".join(obj["dnsserver"]) else: - ret['dns_addresses'] = None + ret["dns_addresses"] = None - ret['dns_hostnames'] = None - ret['dns_gateways'] = None + ret["dns_hostnames"] = None + ret["dns_gateways"] = None idx = 1 hosts = list() gateways = list() while True: - host = 'dns{0}host'.format(idx) - gateway = 'dns{0}gw'.format(idx) + host = "dns{0}host".format(idx) + gateway = "dns{0}gw".format(idx) if host not in obj or gateway not in obj: break - hosts.append(obj[host] if obj[host] != '' else 'none') - gateways.append(obj[gateway] if obj[gateway] != '' else 'none') + hosts.append(obj[host] if obj[host] != "" else "none") + gateways.append(obj[gateway] if obj[gateway] != "" else "none") idx += 1 # we have multiple string that can give the same configuration # we remove the ending nones (assuming the user won't specify them for nothing) while True: - if len(hosts) and hosts[-1] == 'none': + if len(hosts) and hosts[-1] == "none": hosts.pop() continue - if len(gateways) and gateways[-1] == 'none': + if len(gateways) and gateways[-1] == "none": gateways.pop() continue break if len(hosts): - ret['dns_hostnames'] = ' '.join(hosts) + ret["dns_hostnames"] = " ".join(hosts) if len(gateways): - ret['dns_gateways'] = ' '.join(gateways) + ret["dns_gateways"] = " ".join(gateways) return ret def main(): - module = AnsibleModule( - argument_spec=SETUP_ARGUMENT_SPEC, - supports_check_mode=True) + module = AnsibleModule(argument_spec=SETUP_ARGUMENT_SPEC, supports_check_mode=True) pfmodule = PFSenseSetupModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_shellcmd.py b/plugins/modules/pfsense_shellcmd.py index 4649cce5..1995ecf2 100644 --- a/plugins/modules/pfsense_shellcmd.py +++ b/plugins/modules/pfsense_shellcmd.py @@ -4,10 +4,11 @@ # Copyright: (c) 2024, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type -DOCUMENTATION = r''' +DOCUMENTATION = r""" --- module: pfsense_shellcmd @@ -38,9 +39,9 @@ type: str author: Orion Poplawski (@opoplawski) -''' +""" -EXAMPLES = r''' +EXAMPLES = r""" - name: Add myitem shellcmd pfsensible.core.pfsense_shellcmd: description: myitem @@ -52,45 +53,50 @@ pfsensible.core.pfsense_shellcmd: description: myitem state: absent -''' -RETURN = r''' +""" +RETURN = r""" commands: description: the set of commands that would be pushed to the remote device (if pfSense had a CLI) returned: always type: list sample: ["create shellcmd 'myitem'", "update shellcmd 'myitem' set ...", "delete shellcmd 'myitem'"] -''' +""" from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) # Compact style SHELLCMD_ARGUMENT_SPEC = dict( # Only description should be required here - othewise you cannot remove an item with just 'description' # Required arguments for creation should be noted in SHELLCMD_REQUIRED_IF = ['state', 'present', ...] below - description=dict(required=True, type='str'), - state=dict(type='str', default='present', choices=['present', 'absent']), - cmd=dict(type='str'), - cmdtype=dict(type='str', choices=['shellcmd', 'earlyshellcmd', 'afterfilterchangeshellcmd', 'disabled'],), + description=dict(required=True, type="str"), + state=dict(type="str", default="present", choices=["present", "absent"]), + cmd=dict(type="str"), + cmdtype=dict( + type="str", + choices=["shellcmd", "earlyshellcmd", "afterfilterchangeshellcmd", "disabled"], + ), ) SHELLCMD_REQUIRED_IF = [ - ['state', 'present', ['cmd']], + ["state", "present", ["cmd"]], ] # default values when creating a new shellcmd SHELLCMD_CREATE_DEFAULT = dict( - cmdtype='shellcmd', + cmdtype="shellcmd", ) -SHELLCMD_PHP_COMMAND_SET = r''' +SHELLCMD_PHP_COMMAND_SET = r""" require_once("shellcmd.inc"); shellcmd_sync_package(); -''' +""" class PFSenseShellcmdModule(PFSenseModuleBase): - """ module managing pfsense shellcmds """ + """module managing pfsense shellcmds""" ############################## # unit tests @@ -98,27 +104,47 @@ class PFSenseShellcmdModule(PFSenseModuleBase): # Must be class method for unit test usage @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return SHELLCMD_ARGUMENT_SPEC def __init__(self, module, pfsense=None): - super(PFSenseShellcmdModule, self).__init__(module, pfsense, package='shellcmd', root='shellcmdsettings', node='config', key='description', - update_php=SHELLCMD_PHP_COMMAND_SET, create_default=SHELLCMD_CREATE_DEFAULT) + super(PFSenseShellcmdModule, self).__init__( + module, + pfsense, + package="shellcmd", + root="shellcmdsettings", + node="config", + key="description", + update_php=SHELLCMD_PHP_COMMAND_SET, + create_default=SHELLCMD_CREATE_DEFAULT, + ) ############################## # XML processing # def _find_target(self): - """ find the XML target_elt """ + """find the XML target_elt""" # There can be only one 'afterfilterchangeshellcmd' shellcmd - if self.params['cmdtype'] == 'afterfilterchangeshellcmd': - result = self.root_elt.findall("{node}[{key}='{value}']".format(node=self.node, key='cmdtype', value='afterfilterchangeshellcmd')) + if self.params["cmdtype"] == "afterfilterchangeshellcmd": + result = self.root_elt.findall( + "{node}[{key}='{value}']".format( + node=self.node, key="cmdtype", value="afterfilterchangeshellcmd" + ) + ) else: - result = self.root_elt.findall("{node}[{key}='{value}']".format(node=self.node, key=self.key, value=self.obj[self.key])) + result = self.root_elt.findall( + "{node}[{key}='{value}']".format( + node=self.node, key=self.key, value=self.obj[self.key] + ) + ) if len(result) == 1: return result[0] elif len(result) > 1: - self.module.fail_json(msg='Found multiple {node}s for {key} {value}.'.format(node=self.node, key=self.key, value=self.obj[self.key])) + self.module.fail_json( + msg="Found multiple {node}s for {key} {value}.".format( + node=self.node, key=self.key, value=self.obj[self.key] + ) + ) else: return None @@ -127,7 +153,8 @@ def main(): module = AnsibleModule( argument_spec=SHELLCMD_ARGUMENT_SPEC, required_if=SHELLCMD_REQUIRED_IF, - supports_check_mode=True) + supports_check_mode=True, + ) pfmodule = PFSenseShellcmdModule(module) # Pass params for testing framework @@ -135,5 +162,5 @@ def main(): pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_user.py b/plugins/modules/pfsense_user.py index 06bbaecb..cca6e1e0 100644 --- a/plugins/modules/pfsense_user.py +++ b/plugins/modules/pfsense_user.py @@ -4,10 +4,11 @@ # Copyright: (c) 2019-2024, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type -DOCUMENTATION = r''' +DOCUMENTATION = r""" --- module: pfsense_user version_added: 0.1.0 @@ -60,9 +61,9 @@ default: false type: bool version_added: 0.7.1 -''' +""" -EXAMPLES = r''' +EXAMPLES = r""" - name: Add operator user pfsense_user: name: operator @@ -75,37 +76,39 @@ pfsense_user: name: operator state: absent -''' +""" -RETURN = r''' +RETURN = r""" -''' +""" import base64 import re from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.module_base import PFSenseModuleBase +from ansible_collections.pfsensible.core.plugins.module_utils.module_base import ( + PFSenseModuleBase, +) USER_ARGUMENT_SPEC = dict( - name=dict(required=True, type='str'), - state=dict(type='str', default='present', choices=['present', 'absent']), - descr=dict(type='str'), - scope=dict(type='str', choices=['user', 'system']), - uid=dict(type='str'), - password=dict(type='str', no_log=True), - groups=dict(type='list', elements='str'), - priv=dict(type='list', elements='str'), - authorizedkeys=dict(type='str'), - disabled=dict(type='bool', default=False), + name=dict(required=True, type="str"), + state=dict(type="str", default="present", choices=["present", "absent"]), + descr=dict(type="str"), + scope=dict(type="str", choices=["user", "system"]), + uid=dict(type="str"), + password=dict(type="str", no_log=True), + groups=dict(type="list", elements="str"), + priv=dict(type="list", elements="str"), + authorizedkeys=dict(type="str"), + disabled=dict(type="bool", default=False), ) USER_CREATE_DEFAULT = dict( - scope='user', + scope="user", ) USER_MAP_PARAM = [ - ('password', 'bcrypt-hash'), + ("password", "bcrypt-hash"), ] @@ -117,13 +120,15 @@ def parse_groups(self, name, params, obj): def p2o_ssh_pub_key(self, name, params, obj): # Allow ssh keys to be clear or base64 encoded - if params[name] is not None and 'ssh-' in params[name]: + if params[name] is not None and "ssh-" in params[name]: obj[name] = base64.b64encode(params[name].encode()).decode() def validate_password(self, password): - if not re.match(r'\$2[aby]\$', str(password)): - raise ValueError('Password (%s) does not appear to be a bcrypt hash' % (password)) + if not re.match(r"\$2[aby]\$", str(password)): + raise ValueError( + "Password (%s) does not appear to be a bcrypt hash" % (password) + ) USER_ARG_ROUTE = dict( @@ -138,7 +143,9 @@ def validate_password(self, password): $group_config = config_get_path('system/group'); """ -USER_PHP_COMMAND_SET = USER_PHP_COMMAND_PREFIX + """ +USER_PHP_COMMAND_SET = ( + USER_PHP_COMMAND_PREFIX + + """ $userent = config_get_path('system/user')[{idx}]; local_user_set($userent); foreach ({mod_groups} as $groupname) {{ @@ -149,9 +156,12 @@ def validate_password(self, password): run_plugins("/etc/inc/privhooks"); }} """ +) # This runs after we remove the group from the config so we can't use $config -USER_PHP_COMMAND_DEL = USER_PHP_COMMAND_PREFIX + """ +USER_PHP_COMMAND_DEL = ( + USER_PHP_COMMAND_PREFIX + + """ $userent['name'] = '{name}'; $userent['uid'] = {uid}; foreach ({mod_groups} as $groupname) {{ @@ -160,10 +170,11 @@ def validate_password(self, password): }} local_user_del($userent); """ +) class PFSenseUserModule(PFSenseModuleBase): - """ module managing pfsense users """ + """module managing pfsense users""" ############################## # unit tests @@ -171,13 +182,21 @@ class PFSenseUserModule(PFSenseModuleBase): # Must be class method for unit test usage @staticmethod def get_argument_spec(): - """ return argument spec """ + """return argument spec""" return USER_ARGUMENT_SPEC def __init__(self, module, pfsense=None): - super(PFSenseUserModule, self).__init__(module, pfsense, root='system', node='user', key='name', - arg_route=USER_ARG_ROUTE, map_param=USER_MAP_PARAM, create_default=USER_CREATE_DEFAULT) - self.groups = self.root_elt.findall('group') + super(PFSenseUserModule, self).__init__( + module, + pfsense, + root="system", + node="user", + key="name", + arg_route=USER_ARG_ROUTE, + map_param=USER_MAP_PARAM, + create_default=USER_CREATE_DEFAULT, + ) + self.groups = self.root_elt.findall("group") self.user_groups = None self.mod_groups = [] @@ -185,16 +204,24 @@ def __init__(self, module, pfsense=None): # XML processing # def _find_group_elt(self, name): - return self.pfsense.find_elt('group', name, search_field='name', root_elt=self.root_elt) + return self.pfsense.find_elt( + "group", name, search_field="name", root_elt=self.root_elt + ) def _find_group_names_for_uid(self, uid): groups = [] - for group_elt in self.pfsense.find_elt("group", uid, search_field="member", root_elt=self.root_elt, multiple_ok=True): - groups.append(group_elt.find('name').text) + for group_elt in self.pfsense.find_elt( + "group", + uid, + search_field="member", + root_elt=self.root_elt, + multiple_ok=True, + ): + groups.append(group_elt.find("name").text) return groups def _nextuid(self): - nextuid_elt = self.root_elt.find('nextuid') + nextuid_elt = self.root_elt.find("nextuid") nextuid = nextuid_elt.text nextuid_elt.text = str(int(nextuid) + 1) return nextuid @@ -206,14 +233,14 @@ def _format_diff_priv(self, priv): return priv def _copy_and_add_target(self): - """ populate the XML target_elt """ + """populate the XML target_elt""" obj = self.obj - if 'bcrypt-hash' not in obj: - self.module.fail_json(msg='Password is required when adding a user') - if 'uid' not in obj: - obj['uid'] = self._nextuid() + if "bcrypt-hash" not in obj: + self.module.fail_json(msg="Password is required when adding a user") + if "uid" not in obj: + obj["uid"] = self._nextuid() - self.diff['after'] = obj + self.diff["after"] = obj self.pfsense.copy_dict_to_element(self.obj, self.target_elt) self._update_groups() self.root_elt.insert(self._find_last_element_index(), self.target_elt) @@ -221,15 +248,17 @@ def _copy_and_add_target(self): self.elements = self.root_elt.findall(self.node) def _copy_and_update_target(self): - """ update the XML target_elt """ + """update the XML target_elt""" before = self.pfsense.element_to_dict(self.target_elt) - self.diff['before'] = before - if 'priv' in before: - before['priv'] = self._format_diff_priv(before['priv']) + self.diff["before"] = before + if "priv" in before: + before["priv"] = self._format_diff_priv(before["priv"]) changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) - if 'priv' in self.diff['after']: - self.diff['after']['priv'] = self._format_diff_priv(self.diff['after']['priv']) + self.diff["after"] = self.pfsense.element_to_dict(self.target_elt) + if "priv" in self.diff["after"]: + self.diff["after"]["priv"] = self._format_diff_priv( + self.diff["after"]["priv"] + ) if self._remove_deleted_disabled_param(): changed = True if self._update_groups(): @@ -244,27 +273,27 @@ def _update_groups(self): # Only modify group membership is groups was specified if self.user_groups is not None: # Handle group member element - need uid set or retrieved above - uid = self.target_elt.find('uid').text + uid = self.target_elt.find("uid").text # Get current group membership - self.diff['before']['groups'] = self._find_group_names_for_uid(uid) + self.diff["before"]["groups"] = self._find_group_names_for_uid(uid) # Add user to groups if needed for group in self.user_groups: group_elt = self._find_group_elt(group) if group_elt is None: - self.module.fail_json(msg='Group (%s) does not exist' % group) + self.module.fail_json(msg="Group (%s) does not exist" % group) if len(group_elt.findall("[member='{0}']".format(uid))) == 0: changed = True self.mod_groups.append(group) - group_elt.append(self.pfsense.new_element('member', uid)) + group_elt.append(self.pfsense.new_element("member", uid)) # Remove user from groups if needed - for group in self.diff['before']['groups']: + for group in self.diff["before"]["groups"]: if group not in self.user_groups: group_elt = self._find_group_elt(group) if group_elt is None: - self.module.fail_json(msg='Group (%s) does not exist' % group) - for member_elt in group_elt.findall('member'): + self.module.fail_json(msg="Group (%s) does not exist" % group) + for member_elt in group_elt.findall("member"): if member_elt.text == uid: changed = True self.mod_groups.append(group) @@ -272,20 +301,24 @@ def _update_groups(self): break # Groups are not stored in the user element - self.diff['after']['groups'] = self.user_groups + self.diff["after"]["groups"] = self.user_groups # Decode keys for diff for k in self.diff: - if 'authorizedkeys' in self.diff[k]: - self.diff[k]['authorizedkeys'] = base64.b64decode(self.diff[k]['authorizedkeys']) + if "authorizedkeys" in self.diff[k]: + self.diff[k]["authorizedkeys"] = base64.b64decode( + self.diff[k]["authorizedkeys"] + ) return changed def _remove_deleted_disabled_param(self): - """ Remove disabled param if user is re-enabled """ + """Remove disabled param if user is re-enabled""" changed = False - if self.pfsense.remove_deleted_param_from_elt(self.target_elt, 'disabled', self.obj): + if self.pfsense.remove_deleted_param_from_elt( + self.target_elt, "disabled", self.obj + ): changed = True return changed @@ -294,37 +327,47 @@ def _remove_deleted_disabled_param(self): # run # def _update(self): - if self.params['state'] == 'present': - return self.pfsense.phpshell(USER_PHP_COMMAND_SET.format(idx=self._find_this_element_index(), mod_groups=self.mod_groups)) + if self.params["state"] == "present": + return self.pfsense.phpshell( + USER_PHP_COMMAND_SET.format( + idx=self._find_this_element_index(), mod_groups=self.mod_groups + ) + ) else: - return self.pfsense.phpshell(USER_PHP_COMMAND_DEL.format(name=self.obj['name'], uid=self.obj['uid'], mod_groups=self.mod_groups)) + return self.pfsense.phpshell( + USER_PHP_COMMAND_DEL.format( + name=self.obj["name"], + uid=self.obj["uid"], + mod_groups=self.mod_groups, + ) + ) def _pre_remove_target_elt(self): - self.diff['after'] = {} + self.diff["after"] = {} if self.target_elt is not None: - self.diff['before'] = self.pfsense.element_to_dict(self.target_elt) + self.diff["before"] = self.pfsense.element_to_dict(self.target_elt) # Store uid for _update() - self.obj['uid'] = self.target_elt.find('uid').text + self.obj["uid"] = self.target_elt.find("uid").text # Get current group membership - self.diff['before']['groups'] = self._find_group_names_for_uid(self.obj['uid']) + self.diff["before"]["groups"] = self._find_group_names_for_uid( + self.obj["uid"] + ) # Remove user from groups if needed - for group in self.diff['before']['groups']: + for group in self.diff["before"]["groups"]: group_elt = self._find_group_elt(group) if group_elt is None: - self.module.fail_json(msg='Group (%s) does not exist' % group) - for member_elt in group_elt.findall('member'): - if member_elt.text == self.obj['uid']: + self.module.fail_json(msg="Group (%s) does not exist" % group) + for member_elt in group_elt.findall("member"): + if member_elt.text == self.obj["uid"]: self.mod_groups.append(group) group_elt.remove(member_elt) break def main(): - module = AnsibleModule( - argument_spec=USER_ARGUMENT_SPEC, - supports_check_mode=True) + module = AnsibleModule(argument_spec=USER_ARGUMENT_SPEC, supports_check_mode=True) pfmodule = PFSenseUserModule(module) # Pass params for testing framework @@ -332,5 +375,5 @@ def main(): pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/plugins/modules/pfsense_vlan.py b/plugins/modules/pfsense_vlan.py index 0ee871ca..64cc4191 100644 --- a/plugins/modules/pfsense_vlan.py +++ b/plugins/modules/pfsense_vlan.py @@ -6,11 +6,14 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import absolute_import, division, print_function + __metaclass__ = type -ANSIBLE_METADATA = {'metadata_version': '1.1', - 'status': ['preview'], - 'supported_by': 'community'} +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "community", +} DOCUMENTATION = """ --- @@ -70,18 +73,19 @@ """ from ansible.module_utils.basic import AnsibleModule -from ansible_collections.pfsensible.core.plugins.module_utils.vlan import PFSenseVlanModule, VLAN_ARGUMENT_SPEC +from ansible_collections.pfsensible.core.plugins.module_utils.vlan import ( + PFSenseVlanModule, + VLAN_ARGUMENT_SPEC, +) def main(): - module = AnsibleModule( - argument_spec=VLAN_ARGUMENT_SPEC, - supports_check_mode=True) + module = AnsibleModule(argument_spec=VLAN_ARGUMENT_SPEC, supports_check_mode=True) pfmodule = PFSenseVlanModule(module) pfmodule.run(module.params) pfmodule.commit_changes() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..caac6872 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.black] +line-length = 160 +# Normalize string quotes to a consistent style (double quotes) +skip-string-normalization = false diff --git a/tests/unit/plugins/lookup/test_pfsense.py b/tests/unit/plugins/lookup/test_pfsense.py index b908d5da..17484b14 100644 --- a/tests/unit/plugins/lookup/test_pfsense.py +++ b/tests/unit/plugins/lookup/test_pfsense.py @@ -1,18 +1,23 @@ # Copyright: (c) 2020, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type from collections import OrderedDict import yaml -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) from ansible.plugins.loader import lookup_loader -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ModuleTestCase +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + ModuleTestCase, +) def ordered_dump(data, dumper_cls=yaml.Dumper): - """ dump and return yaml string from data using ordered dicts """ + """dump and return yaml string from data using ordered dicts""" class OrderedDumper(dumper_cls): pass @@ -37,196 +42,229 @@ def __init__(self, *args, **kwargs): self.interfaces = None def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseLookup, self).setUp() self.build_definitions() - self.mock_get_hostname = patch('ansible_collections.pfsensible.core.plugins.lookup.pfsense.LookupModule.get_hostname') + self.mock_get_hostname = patch( + "ansible_collections.pfsensible.core.plugins.lookup.pfsense.LookupModule.get_hostname" + ) get_hostname = self.mock_get_hostname.start() - get_hostname.return_value = ('pf_test1') + get_hostname.return_value = "pf_test1" - self.mock_get_definitions = patch('ansible_collections.pfsensible.core.plugins.lookup.pfsense.LookupModule.get_definitions') + self.mock_get_definitions = patch( + "ansible_collections.pfsensible.core.plugins.lookup.pfsense.LookupModule.get_definitions" + ) self.get_definitions = self.mock_get_definitions.start() self.get_definitions.return_value = self.definitions def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseLookup, self).tearDown() self.mock_get_hostname.stop() self.mock_get_definitions.stop() def build_definitions(self): - """ build definitions base for tests """ + """build definitions base for tests""" self.definitions = OrderedDict() - self.definitions['hosts_aliases'] = OrderedDict() - self.definitions['ports_aliases'] = OrderedDict() - self.definitions['rules'] = OrderedDict() - self.definitions['pfsenses'] = OrderedDict() - self.definitions['pfsenses']['pf_test1'] = OrderedDict() - self.definitions['pfsenses']['pf_test1'] = OrderedDict() - self.definitions['pfsenses']['pf_test1']['interfaces'] = OrderedDict() + self.definitions["hosts_aliases"] = OrderedDict() + self.definitions["ports_aliases"] = OrderedDict() + self.definitions["rules"] = OrderedDict() + self.definitions["pfsenses"] = OrderedDict() + self.definitions["pfsenses"]["pf_test1"] = OrderedDict() + self.definitions["pfsenses"]["pf_test1"] = OrderedDict() + self.definitions["pfsenses"]["pf_test1"]["interfaces"] = OrderedDict() self.interfaces = dict( - WAN=dict(remote_networks='0.0.0.0/0'), - LANA=dict(base='10.20.30.x', remote_base='10.120.x', adjacent_base='10.220.x'), - LANB=dict(base='10.20.40.x', remote_base='10.130.x', adjacent_base='10.230.x'), + WAN=dict(remote_networks="0.0.0.0/0"), + LANA=dict( + base="10.20.30.x", remote_base="10.120.x", adjacent_base="10.220.x" + ), + LANB=dict( + base="10.20.40.x", remote_base="10.130.x", adjacent_base="10.230.x" + ), ) for name, defs in self.interfaces.items(): - self.definitions['pfsenses']['pf_test1']['interfaces'][name] = OrderedDict() - if 'base' in defs: - self.definitions['pfsenses']['pf_test1']['interfaces'][name]['ip'] = defs['base'].replace('x', '1/24') - for param in ['remote_networks', 'adjacent_networks']: + self.definitions["pfsenses"]["pf_test1"]["interfaces"][name] = OrderedDict() + if "base" in defs: + self.definitions["pfsenses"]["pf_test1"]["interfaces"][name]["ip"] = ( + defs["base"].replace("x", "1/24") + ) + for param in ["remote_networks", "adjacent_networks"]: if param in defs: - self.definitions['pfsenses']['pf_test1']['interfaces'][name][param] = defs[param] - if 'remote_base' in defs: - self.definitions['pfsenses']['pf_test1']['interfaces'][name]['remote_networks'] = defs['remote_base'].replace('x', '0.0/16') - if 'adjacent_base' in defs: - self.definitions['pfsenses']['pf_test1']['interfaces'][name]['adjacent_networks'] = defs['adjacent_base'].replace('x', '0.0/16') - - def save_definitions(self, filename='test_definitions.yml'): - """ save generated definitions to file for debbuging """ - with open(filename, 'w') as outfile: + self.definitions["pfsenses"]["pf_test1"]["interfaces"][name][ + param + ] = defs[param] + if "remote_base" in defs: + self.definitions["pfsenses"]["pf_test1"]["interfaces"][name][ + "remote_networks" + ] = defs["remote_base"].replace("x", "0.0/16") + if "adjacent_base" in defs: + self.definitions["pfsenses"]["pf_test1"]["interfaces"][name][ + "adjacent_networks" + ] = defs["adjacent_base"].replace("x", "0.0/16") + + def save_definitions(self, filename="test_definitions.yml"): + """save generated definitions to file for debbuging""" + with open(filename, "w") as outfile: outfile.write(ordered_dump(self.definitions)) def run_rules(self): - """ run the plugin for rules """ - pfsense_lookup = lookup_loader.get('pfsensible.core.pfsense') - self.rules = pfsense_lookup.run(['dummy.yml', 'rules'], {})[0] + """run the plugin for rules""" + pfsense_lookup = lookup_loader.get("pfsensible.core.pfsense") + self.rules = pfsense_lookup.run(["dummy.yml", "rules"], {})[0] def assert_get_rule(self, rule_name, count=1): - """ check that rule_name is defined """ + """check that rule_name is defined""" rules = [] for rule in self.rules: - if rule['name'] == rule_name: + if rule["name"] == rule_name: rules.append(rule) if count == 1 and len(rules) == 0: - self.fail('{0} not found'.format(rule_name)) + self.fail("{0} not found".format(rule_name)) if count == 1 and len(rules) > 1: - self.fail('Multiples {0} found: {1}'.format(rule_name, rules)) + self.fail("Multiples {0} found: {1}".format(rule_name, rules)) self.assertEqual(len(rules), count) if count == 1: return rules[0] return rules def assert_rule_not_found(self, rule_name): - """ check that rule_name is not defined """ + """check that rule_name is not defined""" for rule in self.rules: - if rule['name'] == rule_name: - self.fail('{0} found'.format(rule_name)) + if rule["name"] == rule_name: + self.fail("{0} found".format(rule_name)) @staticmethod def add_missing_fields(expected_rule, rule): - """ add missing generated field with default values """ - for param in ['ackqueue', 'gateway', 'icmptype', 'in_queue', 'out_queue', 'queue', 'log', 'sched']: + """add missing generated field with default values""" + for param in [ + "ackqueue", + "gateway", + "icmptype", + "in_queue", + "out_queue", + "queue", + "log", + "sched", + ]: if param not in expected_rule and param in rule: expected_rule[param] = None - if 'action' not in expected_rule: - expected_rule['action'] = 'pass' + if "action" not in expected_rule: + expected_rule["action"] = "pass" - if 'state' not in expected_rule: - expected_rule['state'] = 'present' + if "state" not in expected_rule: + expected_rule["state"] = "present" @staticmethod def correct_aliases(expected_rule): - """ we correct IP values with interface names """ + """we correct IP values with interface names""" translations = { - '10.20.30.1': 'IP:LANA', + "10.20.30.1": "IP:LANA", # '10.20.30.3': 'IP:LANB', } - for field in ['source', 'destination']: + for field in ["source", "destination"]: if expected_rule[field] in translations: expected_rule[field] = translations[expected_rule[field]] def compare_rules(self, expected_rule, rule): - """ compare rule with the expected result """ - if 'after' in rule: - del rule['after'] + """compare rule with the expected result""" + if "after" in rule: + del rule["after"] self.add_missing_fields(expected_rule, rule) self.correct_aliases(expected_rule) self.assertEqual(expected_rule, rule) def gen_rule(self, src, dst, interface, action): - """ generate rule definition according parameters """ + """generate rule definition according parameters""" rule = OrderedDict() - rule['protocol'] = 'any' - rule['name'] = src + '_' + dst + '_' + interface + '_' + action - if src == 'l': - rule['src'] = self.interfaces['LANA']['base'].replace('x', '2') - elif src == 's': - rule['src'] = self.interfaces['LANA']['base'].replace('x', '1') - elif src == 'r': - rule['src'] = self.interfaces['LANA']['remote_base'].replace('x', '30.30') - elif src == 'a': - rule['src'] = self.interfaces['LANA']['adjacent_base'].replace('x', '30.30') - - if interface == 's': - if dst == 'l': - rule['dst'] = self.interfaces['LANA']['base'].replace('x', '3') - elif dst == 's': - rule['dst'] = self.interfaces['LANA']['base'].replace('x', '1') - elif dst == 'r': - rule['dst'] = self.interfaces['LANA']['remote_base'].replace('x', '30.40') - elif dst == 'a': - rule['dst'] = self.interfaces['LANA']['adjacent_base'].replace('x', '30.40') + rule["protocol"] = "any" + rule["name"] = src + "_" + dst + "_" + interface + "_" + action + if src == "l": + rule["src"] = self.interfaces["LANA"]["base"].replace("x", "2") + elif src == "s": + rule["src"] = self.interfaces["LANA"]["base"].replace("x", "1") + elif src == "r": + rule["src"] = self.interfaces["LANA"]["remote_base"].replace("x", "30.30") + elif src == "a": + rule["src"] = self.interfaces["LANA"]["adjacent_base"].replace("x", "30.30") + + if interface == "s": + if dst == "l": + rule["dst"] = self.interfaces["LANA"]["base"].replace("x", "3") + elif dst == "s": + rule["dst"] = self.interfaces["LANA"]["base"].replace("x", "1") + elif dst == "r": + rule["dst"] = self.interfaces["LANA"]["remote_base"].replace( + "x", "30.40" + ) + elif dst == "a": + rule["dst"] = self.interfaces["LANA"]["adjacent_base"].replace( + "x", "30.40" + ) else: - if dst == 'l': - rule['dst'] = self.interfaces['LANB']['base'].replace('x', '3') - elif dst == 's': - rule['dst'] = self.interfaces['LANB']['base'].replace('x', '1') - elif dst == 'r': - rule['dst'] = self.interfaces['LANB']['remote_base'].replace('x', '30.40') - elif dst == 'a': - rule['dst'] = self.interfaces['LANB']['adjacent_base'].replace('x', '30.40') - - if action == 'p': - rule['action'] = 'pass' - elif action == 'dr': - rule['action'] = 'drop' - elif action == 'dn': - rule['action'] = 'deny' + if dst == "l": + rule["dst"] = self.interfaces["LANB"]["base"].replace("x", "3") + elif dst == "s": + rule["dst"] = self.interfaces["LANB"]["base"].replace("x", "1") + elif dst == "r": + rule["dst"] = self.interfaces["LANB"]["remote_base"].replace( + "x", "30.40" + ) + elif dst == "a": + rule["dst"] = self.interfaces["LANB"]["adjacent_base"].replace( + "x", "30.40" + ) + + if action == "p": + rule["action"] = "pass" + elif action == "dr": + rule["action"] = "drop" + elif action == "dn": + rule["action"] = "deny" return rule def test_basic_generation(self): - """ test simple rules generatation for verifying that remote to remote rules are not generated and almost everything else is """ + """test simple rules generatation for verifying that remote to remote rules are not generated and almost everything else is""" expected_rules = list() not_expected_rules = list() - rules = self.definitions['rules'] + rules = self.definitions["rules"] # we want to generate some rules to check # l => local, r => remote, a => adjacent, s => self # s => same interface, o => other interface # p => pass, dr => drop, dn => deny - for src in ['l', 'r', 'a', 's']: - for dst in ['l', 'r', 'a', 's']: - for interface in ['s', 'o']: - for action in ['p', 'dr', 'dn']: + for src in ["l", "r", "a", "s"]: + for dst in ["l", "r", "a", "s"]: + for interface in ["s", "o"]: + for action in ["p", "dr", "dn"]: rule = self.gen_rule(src, dst, interface, action) - rules[rule['name']] = rule + rules[rule["name"]] = rule generated_rule = dict( - name=rule['name'], - interface='LANA', - source=rule['src'], - destination=rule['dst'], - protocol='any', - action=rule['action'] + name=rule["name"], + interface="LANA", + source=rule["src"], + destination=rule["dst"], + protocol="any", + action=rule["action"], ) # we won't generate remote to remote rules or local to local on the same interface if the traffic is allowed # when the traffic is denied or dropped, we consider for now that every rule should be generated, even if it's seems dumb - if rule['name'] in ['r_r_s_p', 'r_r_o_p', 'l_l_s_p']: + if rule["name"] in ["r_r_s_p", "r_r_o_p", "l_l_s_p"]: not_expected_rules.append(generated_rule) else: expected_rules.append(generated_rule) - del rule['name'] + del rule["name"] self.run_rules() for expected_rule in expected_rules: - rule = self.assert_get_rule(expected_rule['name']) + rule = self.assert_get_rule(expected_rule["name"]) self.compare_rules(expected_rule, rule) for rule in not_expected_rules: - self.assert_rule_not_found(rule['name']) + self.assert_rule_not_found(rule["name"]) diff --git a/tests/unit/plugins/module_utils/test_pfsense.py b/tests/unit/plugins/module_utils/test_pfsense.py index ac2309ff..0aa1c0cc 100644 --- a/tests/unit/plugins/module_utils/test_pfsense.py +++ b/tests/unit/plugins/module_utils/test_pfsense.py @@ -1,23 +1,32 @@ # Copyright: (c) 2022, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch -from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import PFSenseModule -from ansible_collections.pfsensible.core.tests.unit.plugins.modules.pfsense_module import TestPFSenseModule +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) +from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import ( + PFSenseModule, +) +from ansible_collections.pfsensible.core.tests.unit.plugins.modules.pfsense_module import ( + TestPFSenseModule, +) class TestPFSense(TestPFSenseModule): - def __init__(self, *args, **kwargs): super(TestPFSense, self).__init__(*args, **kwargs) def setUp(self): super(TestPFSense, self).setUp() self.pfsense = PFSenseModule(None) - self.mock_get_version = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.get_version', wraps=self.my_get_version) + self.mock_get_version = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.get_version", + wraps=self.my_get_version, + ) self.get_version = self.mock_get_version.start() def tearDown(self): @@ -29,7 +38,7 @@ def my_get_version(self): def test_is_version(self): self.pfsense.pfsense_version = None - self.version = '2.6.0' + self.version = "2.6.0" assert self.pfsense.is_version([2, 5, 0]) assert self.pfsense.is_version([2, 6, 0]) assert not self.pfsense.is_version([2, 7, 0]) @@ -37,7 +46,7 @@ def test_is_version(self): assert not self.pfsense.is_version([2, 5, 0], or_more=False) assert not self.pfsense.is_version([21, 2]) self.pfsense.pfsense_version = None - self.version = '22.02' + self.version = "22.02" assert not self.pfsense.is_version([2, 6, 0]) assert not self.pfsense.is_version([2, 7, 0]) assert self.pfsense.is_version([21, 1]) @@ -49,8 +58,8 @@ def test_is_version(self): def test_is_at_least_2_5_0(self): self.pfsense.pfsense_version = None - self.version = '2.6.0' + self.version = "2.6.0" assert self.pfsense.is_at_least_2_5_0() self.pfsense.pfsense_version = None - self.version = '22.01' + self.version = "22.01" assert self.pfsense.is_at_least_2_5_0() diff --git a/tests/unit/plugins/modules/pfsense_module.py b/tests/unit/plugins/modules/pfsense_module.py index 6f1eb01f..56fb72c8 100644 --- a/tests/unit/plugins/modules/pfsense_module.py +++ b/tests/unit/plugins/modules/pfsense_module.py @@ -3,7 +3,8 @@ # Copyright: (c) 2024, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import os @@ -11,15 +12,23 @@ import json import re -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import AnsibleExitJson, AnsibleFailJson, ModuleTestCase -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import set_module_args +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + AnsibleExitJson, + AnsibleFailJson, + ModuleTestCase, +) +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + set_module_args, +) from tempfile import mkstemp import xml.etree.ElementTree as ET from xml.etree.ElementTree import fromstring, ElementTree -fixture_path = os.path.join(os.path.dirname(__file__), 'fixtures') +fixture_path = os.path.join(os.path.dirname(__file__), "fixtures") fixture_data = {} @@ -53,39 +62,53 @@ def __init__(self, *args, **kwargs): self.pfmodule = None def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseModule, self).setUp() - self.mock_parse = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.ET.parse') + self.mock_parse = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.ET.parse" + ) self.parse = self.mock_parse.start() - self.mock_shutil_move = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.shutil.move') + self.mock_shutil_move = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.shutil.move" + ) self.shutil_move = self.mock_shutil_move.start() - self.mock_php = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php') + self.mock_php = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php" + ) self.php = self.mock_php.start() - self.php.return_value = ['vmx0', 'vmx1', 'vmx2', 'vmx3'] + self.php.return_value = ["vmx0", "vmx1", "vmx2", "vmx3"] - self.mock_phpshell = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.phpshell') + self.mock_phpshell = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.phpshell" + ) self.phpshell = self.mock_phpshell.start() - self.phpshell.return_value = (0, '', '') + self.phpshell.return_value = (0, "", "") - self.mock_mkstemp = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.mkstemp') + self.mock_mkstemp = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.mkstemp" + ) self.mkstemp = self.mock_mkstemp.start() self.mkstemp.return_value = mkstemp() self.tmp_file = self.mkstemp.return_value[1] - self.mock_chmod = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.os.chmod') + self.mock_chmod = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.os.chmod" + ) self.chmod = self.mock_chmod.start() - self.mock_get_version = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.get_version') + self.mock_get_version = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.get_version" + ) self.get_version = self.mock_get_version.start() self.get_version.return_value = "2.5.2" self.maxDiff = None def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseModule, self).tearDown() self.mock_parse.stop() @@ -104,27 +127,29 @@ def tearDown(self): raise def get_args_fields(self): - """ return params fields """ + """return params fields""" try: return self.pfmodule.get_argument_spec().keys() except AttributeError: raise NotImplementedError() def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ + """return target elt from XML""" raise NotImplementedError() def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ + """check XML definition of target elt""" raise NotImplementedError() def check_target_elt_direct(self, target_elt, expected_elt_string): - """ check XML definition of target elt against expected XML """ - target_elt_string = ET.tostring(target_elt, encoding="unicode", short_empty_elements=False) + """check XML definition of target elt against expected XML""" + target_elt_string = ET.tostring( + target_elt, encoding="unicode", short_empty_elements=False + ) self.assertEqual(target_elt_string, expected_elt_string) - def args_from_var(self, var, state='present', **kwargs): - """ return arguments for module from var """ + def args_from_var(self, var, state="present", **kwargs): + """return arguments for module from var""" args = {} fields = self.get_args_fields() @@ -133,43 +158,64 @@ def args_from_var(self, var, state='present', **kwargs): args[field] = var[field] if state is not None: - args['state'] = state + args["state"] = state for key, value in kwargs.items(): args[key] = value return args - def execute_module(self, failed=False, changed=False, commands=None, sort=True, defaults=False, msg=''): + def execute_module( + self, + failed=False, + changed=False, + commands=None, + sort=True, + defaults=False, + msg="", + ): self.load_fixtures() if failed: result = self.failed() - self.assertTrue(result['failed'], result) + self.assertTrue(result["failed"], result) else: result = self.changed(changed) if not failed: - self.assertEqual(result['changed'], changed, result) + self.assertEqual(result["changed"], changed, result) else: - self.assertEqual(result['msg'], msg) + self.assertEqual(result["msg"], msg) if commands is not None: if sort: - self.assertEqual(sorted(commands), sorted(result['commands']), result['commands']) + self.assertEqual( + sorted(commands), sorted(result["commands"]), result["commands"] + ) else: - self.assertEqual(commands, result['commands'], result['commands']) + self.assertEqual(commands, result["commands"], result["commands"]) return result - def do_module_test(self, obj, command=None, changed=True, failed=False, msg=None, delete=False, state='present', expected_elt_string=None, **kwargs): - """ run test """ + def do_module_test( + self, + obj, + command=None, + changed=True, + failed=False, + msg=None, + delete=False, + state="present", + expected_elt_string=None, + **kwargs, + ): + """run test""" if command is not None: command = self.strip_commands(command) obj = self.strip_params(obj) if delete: - state = 'absent' + state = "absent" with set_module_args(self.args_from_var(obj, state=state)): result = self.execute_module(changed=changed, failed=failed, msg=msg) @@ -181,17 +227,17 @@ def do_module_test(self, obj, command=None, changed=True, failed=False, msg=None self.assertFalse(self.load_xml_result()) elif not changed: self.assertFalse(self.load_xml_result()) - self.assertEqual(result['commands'], [], result) + self.assertEqual(result["commands"], [], result) elif delete: self.assertTrue(self.load_xml_result()) target_elt = self.get_target_elt(obj, absent=True, module_result=result) self.assertIsNone(target_elt) - self.assertEqual(result['commands'], command, result) + self.assertEqual(result["commands"], command, result) else: self.assertTrue(self.load_xml_result()) target_elt = self.get_target_elt(obj, module_result=result) self.assertIsNotNone(target_elt) - self.assertEqual(result['commands'], command, result) + self.assertEqual(result["commands"], command, result) if expected_elt_string is not None: self.check_target_elt_direct(target_elt, expected_elt_string) else: @@ -202,7 +248,7 @@ def failed(self): self.module.main() result = exc.exception.args[0] - self.assertTrue(result['failed'], result) + self.assertTrue(result["failed"], result) return result def changed(self, changed=False): @@ -211,50 +257,52 @@ def changed(self, changed=False): result = exc.exception.args[0] - if 'diff' in result: + if "diff" in result: changes = dict() - after = dict(result['diff']['after']) - before = dict(result['diff']['before']) + after = dict(result["diff"]["after"]) + before = dict(result["diff"]["before"]) for item in after: if item in before: if after[item] != before[item]: - changes[item] = str(before[item]) + ' -> ' + str(after[item]) + changes[item] = str(before[item]) + " -> " + str(after[item]) del before[item] else: - changes[item] = 'None -> ' + str(after[item]) + changes[item] = "None -> " + str(after[item]) for item in before: - changes[item] = str(before[item]) + ' -> None' + changes[item] = str(before[item]) + " -> None" if changes: - result['changes'] = changes + result["changes"] = changes - self.assertEqual(result['changed'], changed, result) + self.assertEqual(result["changed"], changed, result) return result def strip_commands(self, commands): - """ remove old or new parameters """ + """remove old or new parameters""" return commands def strip_params(self, params): - """ remove old or new parameters """ + """remove old or new parameters""" return params def get_config_file(self): - """ get config file """ + """get config file""" return self.config_file def load_fixtures(self): - """ loading data """ - self.parse.return_value = ElementTree(fromstring(load_fixture(self.get_config_file()))) + """loading data""" + self.parse.return_value = ElementTree( + fromstring(load_fixture(self.get_config_file())) + ) def load_xml_result(self): - """ load the resulting xml if not already loaded """ + """load the resulting xml if not already loaded""" if self.xml_result is None and os.path.getsize(self.tmp_file) > 0: self.xml_result = ET.parse(self.tmp_file) return self.xml_result is not None @staticmethod def find_xml_tag(parent_tag, elt_filter): - """ return alias named name, having type aliastype """ + """return alias named name, having type aliastype""" for tag in parent_tag: found = True for key, value in elt_filter.items(): @@ -271,62 +319,62 @@ def find_xml_tag(parent_tag, elt_filter): return None def assert_xml_elt_value(self, parent_tag_name, elt_filter, elt_name, elt_value): - """ check the xml elt exist and has the exact value given """ + """check the xml elt exist and has the exact value given""" self.load_xml_result() parent_tag = self.xml_result.find(parent_tag_name) if parent_tag is None: - self.fail('Unable to find tag ' + parent_tag_name) + self.fail("Unable to find tag " + parent_tag_name) tag = self.find_xml_tag(parent_tag, elt_filter) if tag is None: - self.fail('Tag not found: ' + json.dumps(elt_filter)) + self.fail("Tag not found: " + json.dumps(elt_filter)) self.assert_xml_elt_equal(tag, elt_name, elt_value) def assert_xml_elt_dict(self, parent_tag_name, elt_filter, elts): - """ check all the xml elt in elts exist and have the exact value given """ + """check all the xml elt in elts exist and have the exact value given""" self.load_xml_result() parent_tag = self.xml_result.find(parent_tag_name) if parent_tag is None: - self.fail('Unable to find tag ' + parent_tag_name) + self.fail("Unable to find tag " + parent_tag_name) tag = self.find_xml_tag(parent_tag, elt_filter) if tag is None: - self.fail('Tag not found: ' + json.dumps(elt_filter)) + self.fail("Tag not found: " + json.dumps(elt_filter)) for elt_name, elt_value in elts.items(): self.assert_xml_elt_equal(tag, elt_name, elt_value) def assert_has_xml_tag(self, parent_tag_name, elt_filter, absent=False): - """ check the xml elt exist (or not if absent is True) """ + """check the xml elt exist (or not if absent is True)""" self.load_xml_result() parent_tag = self.xml_result.find(parent_tag_name) if parent_tag is None: - self.fail('Unable to find tag ' + parent_tag_name) + self.fail("Unable to find tag " + parent_tag_name) tag = self.find_xml_tag(parent_tag, elt_filter) if absent and tag is not None: - self.fail('Tag found: ' + json.dumps(elt_filter)) + self.fail("Tag found: " + json.dumps(elt_filter)) elif not absent and tag is None: - self.fail('Tag not found: ' + json.dumps(elt_filter)) + self.fail("Tag not found: " + json.dumps(elt_filter)) return tag def assert_find_xml_elt(self, tag, elt_name): elt = tag.find(elt_name) if elt is None: - self.fail('Element not found: ' + elt_name) + self.fail("Element not found: " + elt_name) return elt def assert_not_find_xml_elt(self, tag, elt_name): elt = tag.find(elt_name) if elt is not None: - self.fail('Element found: ' + elt_name) + self.fail("Element found: " + elt_name) return elt def assert_xml_elt_equal(self, tag, elt_name, elt_value): elt = tag.find(elt_name) if elt is None: - self.fail('Element not found: ' + elt_name + ' in tag:' + tag) + self.fail("Element not found: " + elt_name + " in tag:" + tag) if isinstance(elt_value, int): value = str(elt_value) @@ -335,21 +383,49 @@ def assert_xml_elt_equal(self, tag, elt_name, elt_value): if elt.text != value: if elt.text is None: - self.fail('Element <' + elt_name + '> differs. Expected: \'' + str(value) + '\' result: None') + self.fail( + "Element <" + + elt_name + + "> differs. Expected: '" + + str(value) + + "' result: None" + ) else: - self.fail('Element <' + elt_name + '> differs. Expected: \'' + str(value) + '\' result: \'' + elt.text + '\'') + self.fail( + "Element <" + + elt_name + + "> differs. Expected: '" + + str(value) + + "' result: '" + + elt.text + + "'" + ) return elt def assert_xml_elt_match(self, tag, elt_name, elt_regex): elt = tag.find(elt_name) if elt is None: - self.fail('Element not found: ' + elt_name) + self.fail("Element not found: " + elt_name) if re.fullmatch(elt_regex, elt.text) is None: if elt.text is None: - self.fail('Element <' + elt_name + '> does not match \'' + elt_regex + '\' result: None') + self.fail( + "Element <" + + elt_name + + "> does not match '" + + elt_regex + + "' result: None" + ) else: - self.fail('Element <' + elt_name + '> does not match \'' + elt_regex + '\' result: \'' + elt.text + '\'') + self.fail( + "Element <" + + elt_name + + "> does not match '" + + elt_regex + + "' result: '" + + elt.text + + "'" + ) return elt def assert_xml_elt_is_none_or_empty(self, tag, elt_name): @@ -357,46 +433,87 @@ def assert_xml_elt_is_none_or_empty(self, tag, elt_name): if elt is None: return elt if elt.text is not None and elt.text: - self.fail('Element <' + elt_name + '> differs. Expected: NoneType result: \'' + elt.text + '\'') + self.fail( + "Element <" + + elt_name + + "> differs. Expected: NoneType result: '" + + elt.text + + "'" + ) return elt def assert_list_xml_elt_equal(self, tag, elt_name, elt_value): elts = tag.findall(elt_name) if elts is None: - self.fail('Element not found: ' + elt_name) + self.fail("Element not found: " + elt_name) elt_value_copy = list(elt_value) elt_texts = [] for elt in elts: if elt.text not in elt_value_copy: if elt.text is None: - self.fail('Element <' + elt_name + '> differs. Expected: \'' + str(elt_value) + '\' result: None') + self.fail( + "Element <" + + elt_name + + "> differs. Expected: '" + + str(elt_value) + + "' result: None" + ) else: - self.fail('Element <' + elt_name + '> differs. Expected: \'' + str(elt_value) + '\' result: \'' + elt.text + '\'') + self.fail( + "Element <" + + elt_name + + "> differs. Expected: '" + + str(elt_value) + + "' result: '" + + elt.text + + "'" + ) elt_value_copy.remove(elt.text) elt_texts.append(elt.text) if len(elt_value_copy): - self.fail('Element <' + elt_name + '> differs. Expected: \'' + str(elt_value) + '\' result: \'' + str(elt_texts) + '\'') + self.fail( + "Element <" + + elt_name + + "> differs. Expected: '" + + str(elt_value) + + "' result: '" + + str(elt_texts) + + "'" + ) return elts @staticmethod def unalias_interface(interface, physical=False): - """ return real alias name if required """ + """return real alias name if required""" res = [] if physical: - interfaces = dict(lan='vmx1', wan='vmx0', opt1='vmx2', vpn='vmx2', opt2='vmx3', vt1='vmx3', opt3='vmx3.100', lan_100='vmx3.100') + interfaces = dict( + lan="vmx1", + wan="vmx0", + opt1="vmx2", + vpn="vmx2", + opt2="vmx3", + vt1="vmx3", + opt3="vmx3.100", + lan_100="vmx3.100", + ) else: - interfaces = dict(lan='lan', wan='wan', vpn='opt1', vt1='opt2', lan_100='opt3') - if interface.startswith('vip:'): - return '_vip602874de0ff00' - for iface in interface.split(','): + interfaces = dict( + lan="lan", wan="wan", vpn="opt1", vt1="opt2", lan_100="opt3" + ) + if interface.startswith("vip:"): + return "_vip602874de0ff00" + for iface in interface.split(","): if interface in interfaces: res.append(interfaces[iface]) else: res.append(iface) - return ','.join(res) + return ",".join(res) - def check_param_equal(self, params, target_elt, param, default=None, xml_field=None, not_find_val=None): - """ if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML """ + def check_param_equal( + self, params, target_elt, param, default=None, xml_field=None, not_find_val=None + ): + """if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML""" if xml_field is None: xml_field = param @@ -412,9 +529,18 @@ def check_param_equal(self, params, target_elt, param, default=None, xml_field=N else: self.assert_xml_elt_is_none_or_empty(target_elt, xml_field) - def check_param_bool(self, params, target_elt, param, default=False, value_true=None, value_false=None, xml_field=None): - """ if param is defined, check the elt exist and text equals value_true, otherwise that it does not exist in XML or - is empty if value_true is not None or equals value_false if set """ + def check_param_bool( + self, + params, + target_elt, + param, + default=False, + value_true=None, + value_false=None, + xml_field=None, + ): + """if param is defined, check the elt exist and text equals value_true, otherwise that it does not exist in XML or + is empty if value_true is not None or equals value_false if set""" if xml_field is None: xml_field = param @@ -433,7 +559,7 @@ def check_param_bool(self, params, target_elt, param, default=False, value_true= self.assert_xml_elt_is_none_or_empty(target_elt, xml_field) def check_value_equal(self, target_elt, xml_field, value, empty=True): - """ if value is defined, check if target_elt has the right value, otherwise that it does not exist in XML """ + """if value is defined, check if target_elt has the right value, otherwise that it does not exist in XML""" if value is None: if empty: self.assert_xml_elt_is_none_or_empty(target_elt, xml_field) @@ -442,8 +568,10 @@ def check_value_equal(self, target_elt, xml_field, value, empty=True): else: self.assert_xml_elt_equal(target_elt, xml_field, value) - def check_param_equal_or_not_find(self, params, target_elt, param, xml_field=None, not_find_val=None, empty=False): - """ if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML """ + def check_param_equal_or_not_find( + self, params, target_elt, param, xml_field=None, not_find_val=None, empty=False + ): + """if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML""" if xml_field is None: xml_field = param if param in params: @@ -457,7 +585,7 @@ def check_param_equal_or_not_find(self, params, target_elt, param, xml_field=Non self.assert_not_find_xml_elt(target_elt, xml_field) def check_param_equal_or_present(self, params, target_elt, param, xml_field=None): - """ if param is defined, check if target_elt has the right value, otherwise that it is present in XML """ + """if param is defined, check if target_elt has the right value, otherwise that it is present in XML""" if xml_field is None: xml_field = param if param in params: @@ -465,8 +593,10 @@ def check_param_equal_or_present(self, params, target_elt, param, xml_field=None else: self.assert_find_xml_elt(target_elt, xml_field) - def check_list_param_equal(self, params, target_elt, param, default=None, xml_field=None, not_find_val=None): - """ if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML """ + def check_list_param_equal( + self, params, target_elt, param, default=None, xml_field=None, not_find_val=None + ): + """if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML""" if xml_field is None: xml_field = param @@ -482,8 +612,10 @@ def check_list_param_equal(self, params, target_elt, param, default=None, xml_fi else: self.assert_xml_elt_is_none_or_empty(target_elt, xml_field) - def check_list_param_equal_or_not_find(self, params, target_elt, param, xml_field=None, not_find_val=None, empty=False): - """ if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML """ + def check_list_param_equal_or_not_find( + self, params, target_elt, param, xml_field=None, not_find_val=None, empty=False + ): + """if param is defined, check if target_elt has the right value, otherwise that it does not exist in XML""" if xml_field is None: xml_field = param if param in params: diff --git a/tests/unit/plugins/modules/test_pfsense_aggregate.py b/tests/unit/plugins/modules/test_pfsense_aggregate.py index a867c809..16b9d2ee 100644 --- a/tests/unit/plugins/modules/test_pfsense_aggregate.py +++ b/tests/unit/plugins/modules/test_pfsense_aggregate.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -10,377 +11,473 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import set_module_args +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + set_module_args, +) from ansible_collections.pfsensible.core.plugins.modules import pfsense_aggregate from .pfsense_module import TestPFSenseModule class TestPFSenseAggregateModule(TestPFSenseModule): - module = pfsense_aggregate def __init__(self, *args, **kwargs): super(TestPFSenseAggregateModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_aggregate_config.xml' + self.config_file = "pfsense_aggregate_config.xml" def assert_find_alias(self, alias): - """ test if an alias exist """ + """test if an alias exist""" self.load_xml_result() - parent_tag = self.xml_result.find('aliases') + parent_tag = self.xml_result.find("aliases") if parent_tag is None: - self.fail('Unable to find tag aliases') + self.fail("Unable to find tag aliases") tag = self.find_xml_tag(parent_tag, dict(name=alias)) if tag is None: - self.fail('Alias not found: ' + alias) + self.fail("Alias not found: " + alias) def assert_not_find_alias(self, alias): - """ test if an alias does not exist """ + """test if an alias does not exist""" self.load_xml_result() - parent_tag = self.xml_result.find('aliases') + parent_tag = self.xml_result.find("aliases") if parent_tag is None: - self.fail('Unable to find tag aliases') + self.fail("Unable to find tag aliases") tag = self.find_xml_tag(parent_tag, dict(name=alias)) if tag is not None: - self.fail('Alias found: ' + alias) + self.fail("Alias found: " + alias) def assert_find_rule(self, rule, interface): - """ test if a rule exist on interface """ + """test if a rule exist on interface""" self.load_xml_result() - parent_tag = self.xml_result.find('filter') + parent_tag = self.xml_result.find("filter") if parent_tag is None: - self.fail('Unable to find tag filter') + self.fail("Unable to find tag filter") tag = self.find_xml_tag(parent_tag, dict(descr=rule, interface=interface)) if tag is None: - self.fail('Rule not found: ' + rule) + self.fail("Rule not found: " + rule) def assert_not_find_rule(self, rule, interface): - """ test if a rule does not exist on interface """ + """test if a rule does not exist on interface""" self.load_xml_result() - parent_tag = self.xml_result.find('filter') + parent_tag = self.xml_result.find("filter") if parent_tag is None: - self.fail('Unable to find tag filter') + self.fail("Unable to find tag filter") tag = self.find_xml_tag(parent_tag, dict(descr=rule, interface=interface)) if tag is not None: - self.fail('Rule found: ' + rule + ' on ' + interface) + self.fail("Rule found: " + rule + " on " + interface) def assert_find_rule_separator(self, separator, interface): - """ test if a rule separator exist on interface """ + """test if a rule separator exist on interface""" self.load_xml_result() interface = self.unalias_interface(interface) - parent_tag = self.xml_result.find('filter') + parent_tag = self.xml_result.find("filter") if parent_tag is None: - self.fail('Unable to find tag filter') + self.fail("Unable to find tag filter") - separators_elt = parent_tag.find('separator') + separators_elt = parent_tag.find("separator") if parent_tag is None: - self.fail('Unable to find tag separator') + self.fail("Unable to find tag separator") interface_elt = separators_elt.find(interface) if parent_tag is None: - self.fail('Unable to find tag ' + interface) + self.fail("Unable to find tag " + interface) tag = self.find_xml_tag(interface_elt, dict(text=separator)) if tag is None: - self.fail('Separator not found: ' + separator) + self.fail("Separator not found: " + separator) def assert_not_find_rule_separator(self, separator, interface): - """ test if a rule separator dost exist on interface """ + """test if a rule separator dost exist on interface""" self.load_xml_result() interface = self.unalias_interface(interface) - parent_tag = self.xml_result.find('filter') + parent_tag = self.xml_result.find("filter") if parent_tag is None: - self.fail('Unable to find tag filter') + self.fail("Unable to find tag filter") - separators_elt = parent_tag.find('separator') + separators_elt = parent_tag.find("separator") if parent_tag is None: - self.fail('Unable to find tag separator') + self.fail("Unable to find tag separator") interface_elt = separators_elt.find(interface) if parent_tag is None: - self.fail('Unable to find tag ' + interface) + self.fail("Unable to find tag " + interface) tag = self.find_xml_tag(interface_elt, dict(text=separator)) if tag is not None: - self.fail('Separator found: ' + separator) + self.fail("Separator found: " + separator) def assert_find_vlan(self, interface, vlan_id): - """ test if a vlan exist """ + """test if a vlan exist""" self.load_xml_result() - parent_tag = self.xml_result.find('vlans') + parent_tag = self.xml_result.find("vlans") if parent_tag is None: - self.fail('Unable to find tag vlans') + self.fail("Unable to find tag vlans") elt_filter = {} - elt_filter['if'] = interface - elt_filter['tag'] = vlan_id + elt_filter["if"] = interface + elt_filter["tag"] = vlan_id tag = self.find_xml_tag(parent_tag, elt_filter) if tag is None: - self.fail('Vlan not found: {0}.{1}'.format(interface, vlan_id)) + self.fail("Vlan not found: {0}.{1}".format(interface, vlan_id)) def assert_not_find_vlan(self, interface, vlan_id): - """ test if an vlan does not exist """ + """test if an vlan does not exist""" self.load_xml_result() - parent_tag = self.xml_result.find('vlans') + parent_tag = self.xml_result.find("vlans") if parent_tag is None: - self.fail('Unable to find tag vlans') + self.fail("Unable to find tag vlans") elt_filter = {} - elt_filter['if'] = interface - elt_filter['vlan_id'] = vlan_id + elt_filter["if"] = interface + elt_filter["vlan_id"] = vlan_id tag = self.find_xml_tag(parent_tag, elt_filter) if tag is not None: - self.fail('Vlan found: {0}.{1}'.format(interface, vlan_id)) + self.fail("Vlan found: {0}.{1}".format(interface, vlan_id)) ############ # as we rely on sub modules for modifying the xml # we dont perform extensive checks on the xml modifications # we just test if elements are created or deleted, and the respective output def test_aggregate_aliases(self): - """ test creation of a some aliases """ + """test creation of a some aliases""" args = dict( purge_aliases=False, aggregated_aliases=[ - dict(name='one_host', type='host', address='10.9.8.7'), - dict(name='another_host', type='host', address='10.9.8.6'), - dict(name='one_server', type='host', address='192.168.1.165', descr='', detail=''), - dict(name='port_ssh', type='port', address='2222'), - dict(name='port_http', state='absent'), - ] + dict(name="one_host", type="host", address="10.9.8.7"), + dict(name="another_host", type="host", address="10.9.8.6"), + dict( + name="one_server", + type="host", + address="192.168.1.165", + descr="", + detail="", + ), + dict(name="port_ssh", type="port", address="2222"), + dict(name="port_http", state="absent"), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_aliases = [] - result_aliases.append("create alias 'one_host', type='host', address='10.9.8.7'") - result_aliases.append("create alias 'another_host', type='host', address='10.9.8.6'") + result_aliases.append( + "create alias 'one_host', type='host', address='10.9.8.7'" + ) + result_aliases.append( + "create alias 'another_host', type='host', address='10.9.8.6'" + ) result_aliases.append("update alias 'port_ssh' set address='2222'") result_aliases.append("delete alias 'port_http'") - self.assertEqual(result['result_aliases'], result_aliases) - self.assert_find_alias('one_host') - self.assert_find_alias('another_host') - self.assert_find_alias('one_server') - self.assert_find_alias('port_ssh') - self.assert_not_find_alias('port_http') - self.assert_find_alias('port_dns') + self.assertEqual(result["result_aliases"], result_aliases) + self.assert_find_alias("one_host") + self.assert_find_alias("another_host") + self.assert_find_alias("one_server") + self.assert_find_alias("port_ssh") + self.assert_not_find_alias("port_http") + self.assert_find_alias("port_dns") def test_aggregate_aliases_checkmode(self): - """ test creation of a some aliases with check_mode """ + """test creation of a some aliases with check_mode""" args = dict( purge_aliases=False, aggregated_aliases=[ - dict(name='one_host', type='host', address='10.9.8.7'), - dict(name='another_host', type='host', address='10.9.8.6'), - dict(name='one_server', type='host', address='192.168.1.165', descr='', detail=''), - dict(name='port_ssh', type='port', address='2222'), - dict(name='port_http', state='absent'), + dict(name="one_host", type="host", address="10.9.8.7"), + dict(name="another_host", type="host", address="10.9.8.6"), + dict( + name="one_server", + type="host", + address="192.168.1.165", + descr="", + detail="", + ), + dict(name="port_ssh", type="port", address="2222"), + dict(name="port_http", state="absent"), ], _ansible_check_mode=True, ) with set_module_args(args): result = self.execute_module(changed=True) result_aliases = [] - result_aliases.append("create alias 'one_host', type='host', address='10.9.8.7'") - result_aliases.append("create alias 'another_host', type='host', address='10.9.8.6'") + result_aliases.append( + "create alias 'one_host', type='host', address='10.9.8.7'" + ) + result_aliases.append( + "create alias 'another_host', type='host', address='10.9.8.6'" + ) result_aliases.append("update alias 'port_ssh' set address='2222'") result_aliases.append("delete alias 'port_http'") - self.assertEqual(result['result_aliases'], result_aliases) + self.assertEqual(result["result_aliases"], result_aliases) self.assertFalse(self.load_xml_result()) def test_aggregate_aliases_purge(self): - """ test creation of a some aliases with purge """ + """test creation of a some aliases with purge""" args = dict( purge_aliases=True, purge_rules=False, aggregated_aliases=[ - dict(name='one_host', type='host', address='10.9.8.7'), - dict(name='another_host', type='host', address='10.9.8.6'), - dict(name='one_server', type='host', address='192.168.1.165', descr='', detail=''), - dict(name='port_ssh', type='port', address='2222'), - dict(name='port_http', state='absent'), - ] + dict(name="one_host", type="host", address="10.9.8.7"), + dict(name="another_host", type="host", address="10.9.8.6"), + dict( + name="one_server", + type="host", + address="192.168.1.165", + descr="", + detail="", + ), + dict(name="port_ssh", type="port", address="2222"), + dict(name="port_http", state="absent"), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_aliases = [] - result_aliases.append("create alias 'one_host', type='host', address='10.9.8.7'") - result_aliases.append("create alias 'another_host', type='host', address='10.9.8.6'") + result_aliases.append( + "create alias 'one_host', type='host', address='10.9.8.7'" + ) + result_aliases.append( + "create alias 'another_host', type='host', address='10.9.8.6'" + ) result_aliases.append("update alias 'port_ssh' set address='2222'") result_aliases.append("delete alias 'port_http'") result_aliases.append("delete alias 'port_dns'") - self.assertEqual(result['result_aliases'], result_aliases) - self.assert_find_alias('one_host') - self.assert_find_alias('another_host') - self.assert_find_alias('one_server') - self.assert_find_alias('port_ssh') - self.assert_not_find_alias('port_http') - self.assert_not_find_alias('port_dns') + self.assertEqual(result["result_aliases"], result_aliases) + self.assert_find_alias("one_host") + self.assert_find_alias("another_host") + self.assert_find_alias("one_server") + self.assert_find_alias("port_ssh") + self.assert_not_find_alias("port_http") + self.assert_not_find_alias("port_dns") def test_aggregate_rules(self): - """ test creation of a some rules """ + """test creation of a some rules""" args = dict( purge_rules=False, aggregated_rules=[ - dict(name='one_rule', source='any', destination='any', interface='lan'), - dict(name='any2any_ssh', source='any', destination='any:2222', interface='lan', protocol='tcp'), - dict(name='any2any_http', source='any', destination='any:8080', interface='vpn', protocol='tcp'), - dict(name='any2any_ssh', state='absent', interface='vpn'), - ] + dict(name="one_rule", source="any", destination="any", interface="lan"), + dict( + name="any2any_ssh", + source="any", + destination="any:2222", + interface="lan", + protocol="tcp", + ), + dict( + name="any2any_http", + source="any", + destination="any:8080", + interface="vpn", + protocol="tcp", + ), + dict(name="any2any_ssh", state="absent", interface="vpn"), + ], ) with set_module_args(args): self.execute_module(changed=True) - self.assert_find_rule('one_rule', 'lan') - self.assert_find_rule('any2any_ssh', 'lan') - self.assert_find_rule('any2any_http', 'lan') - self.assert_find_rule('any2any_https', 'lan') - self.assert_not_find_rule('any2any_ssh', 'opt1') - self.assert_find_rule('any2any_http', 'opt1') - self.assert_find_rule('any2any_https', 'opt1') + self.assert_find_rule("one_rule", "lan") + self.assert_find_rule("any2any_ssh", "lan") + self.assert_find_rule("any2any_http", "lan") + self.assert_find_rule("any2any_https", "lan") + self.assert_not_find_rule("any2any_ssh", "opt1") + self.assert_find_rule("any2any_http", "opt1") + self.assert_find_rule("any2any_https", "opt1") def test_aggregate_rules_purge(self): - """ test creation of a some rules with purge """ + """test creation of a some rules with purge""" args = dict( purge_rules=True, aggregated_rules=[ - dict(name='one_rule', source='any', destination='any', interface='lan'), - dict(name='any2any_ssh', source='any', destination='any:2222', interface='lan', protocol='tcp'), - dict(name='any2any_http', source='any', destination='any:8080', interface='vpn', protocol='tcp'), - dict(name='any2any_ssh', state='absent', interface='vpn'), - ] + dict(name="one_rule", source="any", destination="any", interface="lan"), + dict( + name="any2any_ssh", + source="any", + destination="any:2222", + interface="lan", + protocol="tcp", + ), + dict( + name="any2any_http", + source="any", + destination="any:8080", + interface="vpn", + protocol="tcp", + ), + dict(name="any2any_ssh", state="absent", interface="vpn"), + ], ) with set_module_args(args): self.execute_module(changed=True) - self.assert_find_rule('one_rule', 'lan') - self.assert_find_rule('any2any_ssh', 'lan') - self.assert_not_find_rule('any2any_http', 'lan') - self.assert_not_find_rule('any2any_https', 'lan') - self.assert_not_find_rule('any2any_ssh', 'opt1') - self.assert_find_rule('any2any_http', 'opt1') - self.assert_not_find_rule('any2any_https', 'opt1') + self.assert_find_rule("one_rule", "lan") + self.assert_find_rule("any2any_ssh", "lan") + self.assert_not_find_rule("any2any_http", "lan") + self.assert_not_find_rule("any2any_https", "lan") + self.assert_not_find_rule("any2any_ssh", "opt1") + self.assert_find_rule("any2any_http", "opt1") + self.assert_not_find_rule("any2any_https", "opt1") def test_aggregate_separators(self): - """ test creation of a some separators """ + """test creation of a some separators""" args = dict( purge_rule_separators=False, aggregated_rule_separators=[ - dict(name='one_separator', interface='lan'), - dict(name='another_separator', interface='lan_100'), - dict(name='another_test_separator', interface='lan', state='absent'), - dict(name='test_separator', interface='lan', before='bottom', color='warning'), - ] + dict(name="one_separator", interface="lan"), + dict(name="another_separator", interface="lan_100"), + dict(name="another_test_separator", interface="lan", state="absent"), + dict( + name="test_separator", + interface="lan", + before="bottom", + color="warning", + ), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_separators = [] - result_separators.append("create rule_separator 'one_separator' on 'lan', color='info'") - result_separators.append("create rule_separator 'another_separator' on 'lan_100', color='info'") - result_separators.append("delete rule_separator 'another_test_separator' on 'lan'") - result_separators.append("update rule_separator 'test_separator' on 'lan' set color='warning', before='bottom'") - - self.assertEqual(result['result_rule_separators'], result_separators) - self.assert_find_rule_separator('one_separator', 'lan') - self.assert_find_rule_separator('another_separator', 'lan_100') - self.assert_not_find_rule_separator('another_test_separator', 'lan') - self.assert_find_rule_separator('test_separator', 'lan') + result_separators.append( + "create rule_separator 'one_separator' on 'lan', color='info'" + ) + result_separators.append( + "create rule_separator 'another_separator' on 'lan_100', color='info'" + ) + result_separators.append( + "delete rule_separator 'another_test_separator' on 'lan'" + ) + result_separators.append( + "update rule_separator 'test_separator' on 'lan' set color='warning', before='bottom'" + ) + + self.assertEqual(result["result_rule_separators"], result_separators) + self.assert_find_rule_separator("one_separator", "lan") + self.assert_find_rule_separator("another_separator", "lan_100") + self.assert_not_find_rule_separator("another_test_separator", "lan") + self.assert_find_rule_separator("test_separator", "lan") def test_aggregate_separators_purge(self): - """ test creation of a some separators with purge """ + """test creation of a some separators with purge""" args = dict( purge_rule_separators=True, aggregated_rule_separators=[ - dict(name='one_separator', interface='lan'), - dict(name='another_separator', interface='lan_100'), - dict(name='another_test_separator', interface='lan', state='absent'), - dict(name='test_separator', interface='lan', before='bottom', color='warning'), - ] + dict(name="one_separator", interface="lan"), + dict(name="another_separator", interface="lan_100"), + dict(name="another_test_separator", interface="lan", state="absent"), + dict( + name="test_separator", + interface="lan", + before="bottom", + color="warning", + ), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_separators = [] - result_separators.append("create rule_separator 'one_separator' on 'lan', color='info'") - result_separators.append("create rule_separator 'another_separator' on 'lan_100', color='info'") - result_separators.append("delete rule_separator 'another_test_separator' on 'lan'") - result_separators.append("update rule_separator 'test_separator' on 'lan' set color='warning', before='bottom'") + result_separators.append( + "create rule_separator 'one_separator' on 'lan', color='info'" + ) + result_separators.append( + "create rule_separator 'another_separator' on 'lan_100', color='info'" + ) + result_separators.append( + "delete rule_separator 'another_test_separator' on 'lan'" + ) + result_separators.append( + "update rule_separator 'test_separator' on 'lan' set color='warning', before='bottom'" + ) result_separators.append("delete rule_separator 'test_separator' on 'wan'") - result_separators.append("delete rule_separator 'last_test_separator' on 'lan'") - result_separators.append("delete rule_separator 'test_sep_floating' on 'floating'") - - self.assertEqual(result['result_rule_separators'], result_separators) - self.assert_find_rule_separator('one_separator', 'lan') - self.assert_find_rule_separator('another_separator', 'lan_100') - self.assert_not_find_rule_separator('another_test_separator', 'lan') - self.assert_find_rule_separator('test_separator', 'lan') - self.assert_not_find_rule_separator('last_test_separator', 'lan') - self.assert_not_find_rule_separator('test_sep_floating', 'floatingrules') + result_separators.append( + "delete rule_separator 'last_test_separator' on 'lan'" + ) + result_separators.append( + "delete rule_separator 'test_sep_floating' on 'floating'" + ) + + self.assertEqual(result["result_rule_separators"], result_separators) + self.assert_find_rule_separator("one_separator", "lan") + self.assert_find_rule_separator("another_separator", "lan_100") + self.assert_not_find_rule_separator("another_test_separator", "lan") + self.assert_find_rule_separator("test_separator", "lan") + self.assert_not_find_rule_separator("last_test_separator", "lan") + self.assert_not_find_rule_separator("test_sep_floating", "floatingrules") def test_aggregate_nat_outbound(self): - """ test creation of some nat outbound """ + """test creation of some nat outbound""" args = dict( purge_nat_outbounds=True, aggregated_nat_outbounds=[ - dict(descr='snat 1', source='192.168.100.0/24', destination='1.1.1.0/24', interface='lan', staticnatport=True), - ] + dict( + descr="snat 1", + source="192.168.100.0/24", + destination="1.1.1.0/24", + interface="lan", + staticnatport=True, + ), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_nat_outbounds = [] result_nat_outbounds.append("delete nat_outbound 'None'") result_nat_outbounds.append( - "create nat_outbound 'snat 1', interface='lan', source='192.168.100.0/24', destination='1.1.1.0/24', staticnatport=True") + "create nat_outbound 'snat 1', interface='lan', source='192.168.100.0/24', destination='1.1.1.0/24', staticnatport=True" + ) - self.assertEqual(result['result_nat_outbounds'], result_nat_outbounds) + self.assertEqual(result["result_nat_outbounds"], result_nat_outbounds) def test_aggregate_vlans(self): - """ test creation of some vlans """ + """test creation of some vlans""" args = dict( purge_vlans=False, aggregated_vlans=[ - dict(vlan_id=100, interface='vmx0', descr='voice'), - dict(vlan_id=1200, interface='vmx1', state='absent'), - dict(vlan_id=101, interface='vmx1', descr='printers'), - dict(vlan_id=102, interface='vmx2', descr='users'), - ] + dict(vlan_id=100, interface="vmx0", descr="voice"), + dict(vlan_id=1200, interface="vmx1", state="absent"), + dict(vlan_id=101, interface="vmx1", descr="printers"), + dict(vlan_id=102, interface="vmx2", descr="users"), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_aliases = [] result_aliases.append("update vlan 'vmx0.100' set descr='voice'") result_aliases.append("delete vlan 'vmx1.1200'") - result_aliases.append("create vlan 'vmx1.101', descr='printers', priority=''") + result_aliases.append( + "create vlan 'vmx1.101', descr='printers', priority=''" + ) result_aliases.append("create vlan 'vmx2.102', descr='users', priority=''") - self.assertEqual(result['result_vlans'], result_aliases) - self.assert_find_vlan('vmx0', '100') - self.assert_not_find_vlan('vmx1', '1200') - self.assert_find_vlan('vmx1', '101') - self.assert_find_vlan('vmx2', '102') + self.assertEqual(result["result_vlans"], result_aliases) + self.assert_find_vlan("vmx0", "100") + self.assert_not_find_vlan("vmx1", "1200") + self.assert_find_vlan("vmx1", "101") + self.assert_find_vlan("vmx2", "102") def test_aggregate_vlans_with_purge(self): - """ test creation of some vlans with purge""" + """test creation of some vlans with purge""" args = dict( purge_vlans=True, aggregated_vlans=[ - dict(vlan_id=1100, interface='vmx1'), - dict(vlan_id=1200, interface='vmx1', state='absent'), - dict(vlan_id=101, interface='vmx1', descr='printers'), - dict(vlan_id=102, interface='vmx2', descr='users'), - ] + dict(vlan_id=1100, interface="vmx1"), + dict(vlan_id=1200, interface="vmx1", state="absent"), + dict(vlan_id=101, interface="vmx1", descr="printers"), + dict(vlan_id=102, interface="vmx2", descr="users"), + ], ) with set_module_args(args): result = self.execute_module(changed=True) result_aliases = [] result_aliases.append("delete vlan 'vmx1.1200'") - result_aliases.append("create vlan 'vmx1.101', descr='printers', priority=''") + result_aliases.append( + "create vlan 'vmx1.101', descr='printers', priority=''" + ) result_aliases.append("create vlan 'vmx2.102', descr='users', priority=''") result_aliases.append("delete vlan 'vmx0.100'") - self.assertEqual(result['result_vlans'], result_aliases) - self.assert_not_find_vlan('vmx1', '1200') - self.assert_find_vlan('vmx1', '101') - self.assert_find_vlan('vmx2', '102') - self.assert_not_find_vlan('vmx0', '100') + self.assertEqual(result["result_vlans"], result_aliases) + self.assert_not_find_vlan("vmx1", "1200") + self.assert_find_vlan("vmx1", "101") + self.assert_find_vlan("vmx2", "102") + self.assert_not_find_vlan("vmx0", "100") diff --git a/tests/unit/plugins/modules/test_pfsense_alias.py b/tests/unit/plugins/modules/test_pfsense_alias.py index 136d5442..5d0eaa9c 100644 --- a/tests/unit/plugins/modules/test_pfsense_alias.py +++ b/tests/unit/plugins/modules/test_pfsense_alias.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type from copy import copy @@ -11,20 +12,23 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import set_module_args +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + set_module_args, +) from ansible_collections.pfsensible.core.plugins.modules import pfsense_alias -from ansible_collections.pfsensible.core.plugins.module_utils.alias import PFSenseAliasModule +from ansible_collections.pfsensible.core.plugins.module_utils.alias import ( + PFSenseAliasModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseAliasModule(TestPFSenseModule): - module = pfsense_alias def __init__(self, *args, **kwargs): super(TestPFSenseAliasModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_alias_config.xml' + self.config_file = "pfsense_alias_config.xml" self.pfmodule = PFSenseAliasModule ######################################################## @@ -32,46 +36,60 @@ def __init__(self, *args, **kwargs): # First we run the module # Then, we check return values # Finally, we check the xml - def do_alias_creation_test(self, alias, set_after=None, unset_after=None, failed=False, msg='', command=None): - """ test creation of a new alias """ + def do_alias_creation_test( + self, + alias, + set_after=None, + unset_after=None, + failed=False, + msg="", + command=None, + ): + """test creation of a new alias""" with set_module_args(self.args_from_var(alias)): result = self.execute_module(changed=True, failed=failed, msg=msg) if not failed: diff = dict(before={}, after=alias) if set_after is not None: - diff['after'].update(set_after) + diff["after"].update(set_after) if unset_after is not None: for n in unset_after: - del diff['after'][n] - self.assertEqual(result['diff'], diff) - self.assert_xml_elt_dict('aliases', dict(name=alias['name'], type=alias['type']), diff['after']) - self.assertEqual(result['commands'], [command]) + del diff["after"][n] + self.assertEqual(result["diff"], diff) + self.assert_xml_elt_dict( + "aliases", + dict(name=alias["name"], type=alias["type"]), + diff["after"], + ) + self.assertEqual(result["commands"], [command]) else: self.assertFalse(self.load_xml_result()) def do_alias_deletion_test(self, alias, command=None): - """ test deletion of an alias """ - with set_module_args(self.args_from_var(alias, 'absent')): + """test deletion of an alias""" + with set_module_args(self.args_from_var(alias, "absent")): result = self.execute_module(changed=True) diff = dict(before=alias, after={}) - self.assertEqual(result['diff'], diff) - self.assert_has_xml_tag('aliases', dict(name=alias['name'], type=alias['type']), absent=True) - self.assertEqual(result['commands'], [command]) + self.assertEqual(result["diff"], diff) + self.assert_has_xml_tag( + "aliases", dict(name=alias["name"], type=alias["type"]), absent=True + ) + self.assertEqual(result["commands"], [command]) def do_alias_update_noop_test(self, alias): - """ test not updating an alias """ + """test not updating an alias""" with set_module_args(self.args_from_var(alias)): result = self.execute_module(changed=False) diff = dict(before=alias, after=alias) - self.assertEqual(result['diff'], diff) + self.assertEqual(result["diff"], diff) self.assertFalse(self.load_xml_result()) - self.assertEqual(result['commands'], []) + self.assertEqual(result["commands"], []) def do_alias_update_field(self, alias, set_after=None, command=None, **kwargs): - """ test updating field of an host alias """ + """test updating field of an host alias""" target = copy(alias) target.update(kwargs) with set_module_args(self.args_from_var(target)): @@ -79,168 +97,294 @@ def do_alias_update_field(self, alias, set_after=None, command=None, **kwargs): diff = dict(before=alias, after=copy(target)) if set_after is not None: - diff['after'].update(set_after) - self.assertEqual(result['diff'], diff) - if alias['type'] in ['host', 'port', 'network']: - self.assert_xml_elt_value('aliases', dict(name=alias['name'], type=alias['type']), 'address', diff['after']['address']) + diff["after"].update(set_after) + self.assertEqual(result["diff"], diff) + if alias["type"] in ["host", "port", "network"]: + self.assert_xml_elt_value( + "aliases", + dict(name=alias["name"], type=alias["type"]), + "address", + diff["after"]["address"], + ) else: - self.assert_xml_elt_value('aliases', dict(name=alias['name'], type=alias['type']), 'url', diff['after']['url']) - self.assertEqual(result['commands'], [command]) + self.assert_xml_elt_value( + "aliases", + dict(name=alias["name"], type=alias["type"]), + "url", + diff["after"]["url"], + ) + self.assertEqual(result["commands"], [command]) ############## # hosts # def test_host_create(self): - """ test creation of a new host alias """ - alias = dict(name='adservers', address='10.0.0.1 10.0.0.2', descr='', type='host', detail='') + """test creation of a new host alias""" + alias = dict( + name="adservers", + address="10.0.0.1 10.0.0.2", + descr="", + type="host", + detail="", + ) command = "create alias 'adservers', type='host', address='10.0.0.1 10.0.0.2', descr='', detail=''" self.do_alias_creation_test(alias, command=command) def test_host_delete(self): - """ test deletion of an host alias """ - alias = dict(name='ad_poc1', address='192.168.1.3', descr='', type='host', detail='') + """test deletion of an host alias""" + alias = dict( + name="ad_poc1", address="192.168.1.3", descr="", type="host", detail="" + ) command = "delete alias 'ad_poc1'" self.do_alias_deletion_test(alias, command=command) def test_host_update_noop(self): - """ test not updating an host alias """ - alias = dict(name='ad_poc1', address='192.168.1.3', descr='', type='host', detail='') + """test not updating an host alias""" + alias = dict( + name="ad_poc1", address="192.168.1.3", descr="", type="host", detail="" + ) self.do_alias_update_noop_test(alias) def test_host_update_ip(self): - """ test updating address of an host alias """ - alias = dict(name='ad_poc1', address='192.168.1.3', descr='', type='host', detail='') + """test updating address of an host alias""" + alias = dict( + name="ad_poc1", address="192.168.1.3", descr="", type="host", detail="" + ) command = "update alias 'ad_poc1' set address='192.168.1.4'" - self.do_alias_update_field(alias, address='192.168.1.4', command=command) + self.do_alias_update_field(alias, address="192.168.1.4", command=command) def test_host_update_descr(self): - """ test updating descr of an host alias """ - alias = dict(name='ad_poc1', address='192.168.1.3', descr='', type='host', detail='') + """test updating descr of an host alias""" + alias = dict( + name="ad_poc1", address="192.168.1.3", descr="", type="host", detail="" + ) command = "update alias 'ad_poc1' set descr='ad server'" - self.do_alias_update_field(alias, descr='ad server', command=command) + self.do_alias_update_field(alias, descr="ad server", command=command) ############## # ports # def test_port_create(self): - """ test creation of a new port alias """ - alias = dict(name='port_proxy', address='8080 8443', descr='', type='port', detail='') + """test creation of a new port alias""" + alias = dict( + name="port_proxy", address="8080 8443", descr="", type="port", detail="" + ) command = "create alias 'port_proxy', type='port', address='8080 8443', descr='', detail=''" self.do_alias_creation_test(alias, command=command) def test_port_delete(self): - """ test deletion of a port alias """ - alias = dict(name='port_ssh', address='22', descr='', type='port', detail='') + """test deletion of a port alias""" + alias = dict(name="port_ssh", address="22", descr="", type="port", detail="") command = "delete alias 'port_ssh'" self.do_alias_deletion_test(alias, command=command) def test_port_update_noop(self): - """ test not updating a port alias """ - alias = dict(name='port_ssh', address='22', descr='', type='port', detail='') + """test not updating a port alias""" + alias = dict(name="port_ssh", address="22", descr="", type="port", detail="") self.do_alias_update_noop_test(alias) def test_port_update_port(self): - """ test updating port of a port alias """ - alias = dict(name='port_ssh', address='22', descr='', type='port', detail='') + """test updating port of a port alias""" + alias = dict(name="port_ssh", address="22", descr="", type="port", detail="") command = "update alias 'port_ssh' set address='2222'" - self.do_alias_update_field(alias, address='2222', command=command) + self.do_alias_update_field(alias, address="2222", command=command) def test_port_update_descr(self): - """ test updating descr of a port alias """ - alias = dict(name='port_ssh', address='22', descr='', type='port', detail='') + """test updating descr of a port alias""" + alias = dict(name="port_ssh", address="22", descr="", type="port", detail="") command = "update alias 'port_ssh' set descr='ssh port'" - self.do_alias_update_field(alias, descr='ssh port', command=command) + self.do_alias_update_field(alias, descr="ssh port", command=command) ############## # networks # def test_network_create(self): - """ test creation of a new network alias """ - alias = dict(name='data_networks', address='192.168.1.0/24 192.168.2.0/24', descr='', type='network', detail='') + """test creation of a new network alias""" + alias = dict( + name="data_networks", + address="192.168.1.0/24 192.168.2.0/24", + descr="", + type="network", + detail="", + ) command = "create alias 'data_networks', type='network', address='192.168.1.0/24 192.168.2.0/24', descr='', detail=''" self.do_alias_creation_test(alias, command=command) def test_network_delete(self): - """ test deletion of a network alias """ - alias = dict(name='lan_data_poc3', address='192.168.3.0/24', descr='', type='network', detail='') + """test deletion of a network alias""" + alias = dict( + name="lan_data_poc3", + address="192.168.3.0/24", + descr="", + type="network", + detail="", + ) command = "delete alias 'lan_data_poc3'" self.do_alias_deletion_test(alias, command=command) def test_network_update_noop(self): - """ test not updating a network alias """ - alias = dict(name='lan_data_poc3', address='192.168.3.0/24', descr='', type='network', detail='') + """test not updating a network alias""" + alias = dict( + name="lan_data_poc3", + address="192.168.3.0/24", + descr="", + type="network", + detail="", + ) self.do_alias_update_noop_test(alias) def test_network_update_network(self): - """ test updating address of a network alias """ - alias = dict(name='lan_data_poc3', address='192.168.3.0/24', descr='', type='network', detail='') + """test updating address of a network alias""" + alias = dict( + name="lan_data_poc3", + address="192.168.3.0/24", + descr="", + type="network", + detail="", + ) command = "update alias 'lan_data_poc3' set address='192.168.2.0/24'" - self.do_alias_update_field(alias, address='192.168.2.0/24', command=command) + self.do_alias_update_field(alias, address="192.168.2.0/24", command=command) def test_network_update_descr(self): - """ test updating descr of a network alias """ - alias = dict(name='lan_data_poc3', address='192.168.3.0/24', descr='', type='network', detail='') + """test updating descr of a network alias""" + alias = dict( + name="lan_data_poc3", + address="192.168.3.0/24", + descr="", + type="network", + detail="", + ) command = "update alias 'lan_data_poc3' set descr='data network'" - self.do_alias_update_field(alias, descr='data network', command=command) + self.do_alias_update_field(alias, descr="data network", command=command) ############## # urltables # def test_urltable_create(self): - """ test creation of a new urltable alias """ - alias = dict(name='acme_table', address='http://www.acme.com', descr='', type='urltable', updatefreq='10', detail='') + """test creation of a new urltable alias""" + alias = dict( + name="acme_table", + address="http://www.acme.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) command = "create alias 'acme_table', type='urltable', url='http://www.acme.com', descr='', detail='', updatefreq='10'" - self.do_alias_creation_test(alias, command=command, set_after=dict(url='http://www.acme.com'), unset_after=['address']) + self.do_alias_creation_test( + alias, + command=command, + set_after=dict(url="http://www.acme.com"), + unset_after=["address"], + ) def test_urltable_create_url(self): - """ test creation of a new urltable alias """ - alias = dict(name='acme_table', url='http://www.acme.com', descr='', type='urltable', updatefreq='10', detail='') + """test creation of a new urltable alias""" + alias = dict( + name="acme_table", + url="http://www.acme.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) command = "create alias 'acme_table', type='urltable', url='http://www.acme.com', descr='', detail='', updatefreq='10'" self.do_alias_creation_test(alias, command=command) def test_urltable_create_exclusive(self): - """ test creattion of a urltable alias with both address and url - fails """ + """test creattion of a urltable alias with both address and url - fails""" alias = dict( - name='acme_corp', address='http://www.acme-corp.com', url='http://www.acme-corp.com', descr='', type='urltable', updatefreq='10', detail='') - self.do_alias_creation_test(alias, failed=True, msg='parameters are mutually exclusive: address|url') + name="acme_corp", + address="http://www.acme-corp.com", + url="http://www.acme-corp.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) + self.do_alias_creation_test( + alias, failed=True, msg="parameters are mutually exclusive: address|url" + ) def test_urltable_delete(self): - """ test deletion of a urltable alias """ + """test deletion of a urltable alias""" alias = dict( - name='acme_corp', url='http://www.acme-corp.com', descr='', type='urltable', updatefreq='10', detail='') + name="acme_corp", + url="http://www.acme-corp.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) command = "delete alias 'acme_corp'" self.do_alias_deletion_test(alias, command=command) def test_urltable_update_noop(self): - """ test not updating a urltable alias """ + """test not updating a urltable alias""" alias = dict( - name='acme_corp', url='http://www.acme-corp.com', descr='', type='urltable', updatefreq='10', detail='') + name="acme_corp", + url="http://www.acme-corp.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) self.do_alias_update_noop_test(alias) def test_urltable_update_url(self): - """ test updating url of a urltable alias """ + """test updating url of a urltable alias""" alias = dict( - name='acme_corp', url='http://www.acme-corp.com', descr='', type='urltable', updatefreq='10', detail='') + name="acme_corp", + url="http://www.acme-corp.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) command = "update alias 'acme_corp' set url='http://www.new-acme-corp.com'" - self.do_alias_update_field(alias, url='http://www.new-acme-corp.com', set_after=dict(url='http://www.new-acme-corp.com'), command=command) + self.do_alias_update_field( + alias, + url="http://www.new-acme-corp.com", + set_after=dict(url="http://www.new-acme-corp.com"), + command=command, + ) def test_urltable_update_descr(self): - """ test updating descr of a urltable alias """ + """test updating descr of a urltable alias""" alias = dict( - name='acme_corp', url='http://www.acme-corp.com', descr='', type='urltable', updatefreq='10', detail='') + name="acme_corp", + url="http://www.acme-corp.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) command = "update alias 'acme_corp' set descr='acme corp urls'" - self.do_alias_update_field(alias, descr='acme corp urls', command=command) + self.do_alias_update_field(alias, descr="acme corp urls", command=command) def test_urltable_update_freq(self): - """ test updating updatefreq of a urltable alias """ + """test updating updatefreq of a urltable alias""" alias = dict( - name='acme_corp', url='http://www.acme-corp.com', descr='', type='urltable', updatefreq='10', detail='') + name="acme_corp", + url="http://www.acme-corp.com", + descr="", + type="urltable", + updatefreq="10", + detail="", + ) command = "update alias 'acme_corp' set updatefreq='20'" - self.do_alias_update_field(alias, updatefreq='20', command=command) + self.do_alias_update_field(alias, updatefreq="20", command=command) def test_urltable_ports_create(self): - """ test creation of a new urltable_ports alias """ - alias = dict(name='acme_table', url='http://www.acme.com', descr='', type='urltable_ports', updatefreq='10', detail='') + """test creation of a new urltable_ports alias""" + alias = dict( + name="acme_table", + url="http://www.acme.com", + descr="", + type="urltable_ports", + updatefreq="10", + detail="", + ) command = "create alias 'acme_table', type='urltable_ports', url='http://www.acme.com', descr='', detail='', updatefreq='10'" self.do_alias_creation_test(alias, command=command) @@ -248,77 +392,141 @@ def test_urltable_ports_create(self): # misc # def test_create_alias_duplicate(self): - """ test creation of a duplicate alias """ - alias = dict(name='port_ssh', address='10.0.0.1 10.0.0.2', type='host') - self.do_alias_creation_test(alias, failed=True, msg="An alias with this name and a different type already exists: 'port_ssh'") + """test creation of a duplicate alias""" + alias = dict(name="port_ssh", address="10.0.0.1 10.0.0.2", type="host") + self.do_alias_creation_test( + alias, + failed=True, + msg="An alias with this name and a different type already exists: 'port_ssh'", + ) def test_create_alias_invalid_name(self): - """ test creation of a new alias with invalid name """ - alias = dict(name='ads-ervers', address='10.0.0.1 10.0.0.2', type='host') + """test creation of a new alias with invalid name""" + alias = dict(name="ads-ervers", address="10.0.0.1 10.0.0.2", type="host") msg = "The alias name 'ads-ervers' must be less than 32 characters long, may not consist of only numbers, may not consist of only underscores, " msg += "and may only contain the following characters: a-z, A-Z, 0-9, _" self.do_alias_creation_test(alias, failed=True, msg=msg) def test_create_alias_invalid_name_interface(self): - """ test creation of a new alias with invalid name """ - alias = dict(name='lan_100', address='10.0.0.1 10.0.0.2', type='host') - self.do_alias_creation_test(alias, failed=True, msg="An interface description with this name already exists: 'lan_100'") + """test creation of a new alias with invalid name""" + alias = dict(name="lan_100", address="10.0.0.1 10.0.0.2", type="host") + self.do_alias_creation_test( + alias, + failed=True, + msg="An interface description with this name already exists: 'lan_100'", + ) def test_create_alias_invalid_updatefreq(self): - """ test creation of a new host alias with incoherent params """ - alias = dict(name='adservers', address='10.0.0.1 10.0.0.2', type='host', updatefreq=10) - self.do_alias_creation_test(alias, failed=True, msg='updatefreq is only valid with type urltable or urltable_ports') + """test creation of a new host alias with incoherent params""" + alias = dict( + name="adservers", address="10.0.0.1 10.0.0.2", type="host", updatefreq=10 + ) + self.do_alias_creation_test( + alias, + failed=True, + msg="updatefreq is only valid with type urltable or urltable_ports", + ) def test_create_alias_without_type(self): - """ test creation of a new host alias without type """ - alias = dict(name='adservers', address='10.0.0.1 10.0.0.2') - self.do_alias_creation_test(alias, failed=True, msg='state is present but all of the following are missing: type') + """test creation of a new host alias without type""" + alias = dict(name="adservers", address="10.0.0.1 10.0.0.2") + self.do_alias_creation_test( + alias, + failed=True, + msg="state is present but all of the following are missing: type", + ) def test_create_alias_without_address(self): - """ test creation of a new host alias without address """ - alias = dict(name='adservers', type='host') - self.do_alias_creation_test(alias, failed=True, msg='type is host but all of the following are missing: address') + """test creation of a new host alias without address""" + alias = dict(name="adservers", type="host") + self.do_alias_creation_test( + alias, + failed=True, + msg="type is host but all of the following are missing: address", + ) def test_create_alias_invalid_details(self): - """ test creation of a new host alias with invalid details """ - alias = dict(name='adservers', address='10.0.0.1 10.0.0.2', type='host', detail='ad1||ad2||ad3') - self.do_alias_creation_test(alias, failed=True, msg='Too many details in relation to addresses') + """test creation of a new host alias with invalid details""" + alias = dict( + name="adservers", + address="10.0.0.1 10.0.0.2", + type="host", + detail="ad1||ad2||ad3", + ) + self.do_alias_creation_test( + alias, failed=True, msg="Too many details in relation to addresses" + ) def test_create_alias_invalid_details2(self): - """ test creation of a new host alias with invalid details """ - alias = dict(name='adservers', address='10.0.0.1 10.0.0.2', type='host', detail='|ad1||ad2') - self.do_alias_creation_test(alias, failed=True, msg='Vertical bars (|) at start or end of descriptions not allowed') + """test creation of a new host alias with invalid details""" + alias = dict( + name="adservers", + address="10.0.0.1 10.0.0.2", + type="host", + detail="|ad1||ad2", + ) + self.do_alias_creation_test( + alias, + failed=True, + msg="Vertical bars (|) at start or end of descriptions not allowed", + ) def test_delete_inexistent_alias(self): - """ test deletion of an inexistent alias """ - alias = dict(name='ad_poc12', address='192.168.1.3', descr='', type='host', detail='') - with set_module_args(self.args_from_var(alias, 'absent')): + """test deletion of an inexistent alias""" + alias = dict( + name="ad_poc12", address="192.168.1.3", descr="", type="host", detail="" + ) + with set_module_args(self.args_from_var(alias, "absent")): result = self.execute_module(changed=False) diff = dict(before={}, after={}) - self.assertEqual(result['diff'], diff) - self.assertEqual(result['commands'], []) + self.assertEqual(result["diff"], diff) + self.assertEqual(result["commands"], []) def test_check_mode(self): - """ test updating an host alias without generating result """ - alias = dict(name='ad_poc1', address='192.168.1.3', descr='', type='host', detail='') - with set_module_args(self.args_from_var(alias, address='192.168.1.4', _ansible_check_mode=True)): + """test updating an host alias without generating result""" + alias = dict( + name="ad_poc1", address="192.168.1.3", descr="", type="host", detail="" + ) + with set_module_args( + self.args_from_var(alias, address="192.168.1.4", _ansible_check_mode=True) + ): result = self.execute_module(changed=True) diff = dict(before=alias, after=copy(alias)) - diff['after']['address'] = '192.168.1.4' - self.assertEqual(result['diff'], diff) + diff["after"]["address"] = "192.168.1.4" + self.assertEqual(result["diff"], diff) self.assertFalse(self.load_xml_result()) - self.assertEqual(result['commands'], ["update alias 'ad_poc1' set address='192.168.1.4'"]) + self.assertEqual( + result["commands"], ["update alias 'ad_poc1' set address='192.168.1.4'"] + ) def test_urltable_required_if(self): - """ test creation of a new urltable alias without giving updatefreq (should fail) """ - alias = dict(name='acme_table', address='http://www.acme.com', descr='', type='urltable', detail='') + """test creation of a new urltable alias without giving updatefreq (should fail)""" + alias = dict( + name="acme_table", + address="http://www.acme.com", + descr="", + type="urltable", + detail="", + ) with set_module_args(self.args_from_var(alias)): - self.execute_module(failed=True, msg='type is urltable but all of the following are missing: updatefreq') + self.execute_module( + failed=True, + msg="type is urltable but all of the following are missing: updatefreq", + ) def test_urltable_ports_required_if(self): - """ test creation of a new urltable_ports alias without giving updatefreq (should fail) """ - alias = dict(name='acme_table', address='http://www.acme.com', descr='', type='urltable_ports', detail='') + """test creation of a new urltable_ports alias without giving updatefreq (should fail)""" + alias = dict( + name="acme_table", + address="http://www.acme.com", + descr="", + type="urltable_ports", + detail="", + ) with set_module_args(self.args_from_var(alias)): - self.execute_module(failed=True, msg='type is urltable_ports but all of the following are missing: updatefreq') + self.execute_module( + failed=True, + msg="type is urltable_ports but all of the following are missing: updatefreq", + ) diff --git a/tests/unit/plugins/modules/test_pfsense_alias_null.py b/tests/unit/plugins/modules/test_pfsense_alias_null.py index 9caf4ad7..a8b24798 100644 --- a/tests/unit/plugins/modules/test_pfsense_alias_null.py +++ b/tests/unit/plugins/modules/test_pfsense_alias_null.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -10,21 +11,24 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import set_module_args +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + set_module_args, +) from ansible_collections.pfsensible.core.plugins.modules import pfsense_alias -from ansible_collections.pfsensible.core.plugins.module_utils.alias import PFSenseAliasModule +from ansible_collections.pfsensible.core.plugins.module_utils.alias import ( + PFSenseAliasModule, +) from .pfsense_module import TestPFSenseModule # Test alias creation starting without an initial element class TestPFSenseAliasNullModule(TestPFSenseModule): - module = pfsense_alias def __init__(self, *args, **kwargs): super(TestPFSenseAliasNullModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_alias_null_config.xml' + self.config_file = "pfsense_alias_null_config.xml" self.pfmodule = PFSenseAliasModule ######################################################## @@ -32,16 +36,20 @@ def __init__(self, *args, **kwargs): # First we run the module # Then, we check return values # Finally, we check the xml - def do_alias_creation_test(self, alias, failed=False, msg='', command=None): - """ test creation of a new alias """ + def do_alias_creation_test(self, alias, failed=False, msg="", command=None): + """test creation of a new alias""" with set_module_args(self.args_from_var(alias)): result = self.execute_module(changed=True, failed=failed, msg=msg) if not failed: diff = dict(before={}, after=alias) - self.assertEqual(result['diff'], diff) - self.assert_xml_elt_dict('aliases', dict(name=alias['name'], type=alias['type']), diff['after']) - self.assertEqual(result['commands'], [command]) + self.assertEqual(result["diff"], diff) + self.assert_xml_elt_dict( + "aliases", + dict(name=alias["name"], type=alias["type"]), + diff["after"], + ) + self.assertEqual(result["commands"], [command]) else: self.assertFalse(self.load_xml_result()) @@ -49,7 +57,13 @@ def do_alias_creation_test(self, alias, failed=False, msg='', command=None): # hosts # def test_host_create(self): - """ test creation of a new host alias """ - alias = dict(name='adservers', address='10.0.0.1 10.0.0.2', descr='', type='host', detail='') + """test creation of a new host alias""" + alias = dict( + name="adservers", + address="10.0.0.1 10.0.0.2", + descr="", + type="host", + detail="", + ) command = "create alias 'adservers', type='host', address='10.0.0.1 10.0.0.2', descr='', detail=''" self.do_alias_creation_test(alias, command=command) diff --git a/tests/unit/plugins/modules/test_pfsense_authserver_ldap.py b/tests/unit/plugins/modules/test_pfsense_authserver_ldap.py index 31e726ad..4070916c 100644 --- a/tests/unit/plugins/modules/test_pfsense_authserver_ldap.py +++ b/tests/unit/plugins/modules/test_pfsense_authserver_ldap.py @@ -1,7 +1,8 @@ # Copyright: (c) 2022, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -15,92 +16,189 @@ class TestPFSenseAuthserverLDAPModule(TestPFSenseModule): - module = pfsense_authserver_ldap def __init__(self, *args, **kwargs): super(TestPFSenseAuthserverLDAPModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_authserver_config.xml' + self.config_file = "pfsense_authserver_config.xml" self.pfmodule = pfsense_authserver_ldap.PFSenseAuthserverLDAPModule @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ - root_elt = self.assert_find_xml_elt(self.xml_result, 'system') - result = root_elt.findall("authserver[name='{0}']".format(obj['name'])) + """return target elt from XML""" + root_elt = self.assert_find_xml_elt(self.xml_result, "system") + result = root_elt.findall("authserver[name='{0}']".format(obj["name"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple authservers for name {0}.'.format(obj['name'])) + self.fail("Found multiple authservers for name {0}.".format(obj["name"])) else: return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ - - urltype = dict({'tcp': 'Standard TCP', 'starttls': 'STARTTLS Encrypted', 'ssl': 'SSL/TLS Encrypted'}) - self.check_param_equal(obj, target_elt, 'name') - self.assert_xml_elt_match(target_elt, 'refid', r'[0-9a-f]{13}') - self.assert_xml_elt_equal(target_elt, 'type', 'ldap') - self.check_param_equal(obj, target_elt, 'ldap_caref', default='global') - self.check_param_equal(obj, target_elt, 'host') - self.check_param_equal(obj, target_elt, 'port', xml_field='ldap_port', default=389) - self.assert_xml_elt_equal(target_elt, 'ldap_urltype', urltype[obj['transport']]) - self.check_param_equal(obj, target_elt, 'protover', xml_field='ldap_protver', default=3) - self.check_param_equal(obj, target_elt, 'scope', xml_field='ldap_scope', default='one') - self.check_param_equal(obj, target_elt, 'basedn', xml_field='ldap_basedn', default=None) - self.check_param_equal(obj, target_elt, 'authcn', xml_field='ldap_authcn') - self.check_param_bool(obj, target_elt, 'extended_enabled', xml_field='ldap_extended_enabled', value_true='yes') - self.check_param_equal(obj, target_elt, 'extended_query', xml_field='ldap_extended_query') - self.check_param_equal(obj, target_elt, 'attr_user', xml_field='ldap_attr_user', default='cn') - self.check_param_equal(obj, target_elt, 'attr_group', xml_field='ldap_attr_group', default='cn') - self.check_param_equal(obj, target_elt, 'attr_member', xml_field='ldap_attr_member', default='member') - self.check_param_equal(obj, target_elt, 'attr_groupobj', xml_field='ldap_attr_groupobj', default='posixGroup') - self.check_param_equal(obj, target_elt, 'pam_groupdn', xml_field='ldap_pam_groupdn', default=None) - self.check_param_bool(obj, target_elt, 'ldap_allow_unauthenticated', xml_field='ldap_allow_unauthenticated', default=True) - self.check_param_equal(obj, target_elt, 'timeout', xml_field='ldap_timeout', default=25) + """check XML definition of target elt""" + + urltype = dict( + { + "tcp": "Standard TCP", + "starttls": "STARTTLS Encrypted", + "ssl": "SSL/TLS Encrypted", + } + ) + self.check_param_equal(obj, target_elt, "name") + self.assert_xml_elt_match(target_elt, "refid", r"[0-9a-f]{13}") + self.assert_xml_elt_equal(target_elt, "type", "ldap") + self.check_param_equal(obj, target_elt, "ldap_caref", default="global") + self.check_param_equal(obj, target_elt, "host") + self.check_param_equal( + obj, target_elt, "port", xml_field="ldap_port", default=389 + ) + self.assert_xml_elt_equal(target_elt, "ldap_urltype", urltype[obj["transport"]]) + self.check_param_equal( + obj, target_elt, "protover", xml_field="ldap_protver", default=3 + ) + self.check_param_equal( + obj, target_elt, "scope", xml_field="ldap_scope", default="one" + ) + self.check_param_equal( + obj, target_elt, "basedn", xml_field="ldap_basedn", default=None + ) + self.check_param_equal(obj, target_elt, "authcn", xml_field="ldap_authcn") + self.check_param_bool( + obj, + target_elt, + "extended_enabled", + xml_field="ldap_extended_enabled", + value_true="yes", + ) + self.check_param_equal( + obj, target_elt, "extended_query", xml_field="ldap_extended_query" + ) + self.check_param_equal( + obj, target_elt, "attr_user", xml_field="ldap_attr_user", default="cn" + ) + self.check_param_equal( + obj, target_elt, "attr_group", xml_field="ldap_attr_group", default="cn" + ) + self.check_param_equal( + obj, + target_elt, + "attr_member", + xml_field="ldap_attr_member", + default="member", + ) + self.check_param_equal( + obj, + target_elt, + "attr_groupobj", + xml_field="ldap_attr_groupobj", + default="posixGroup", + ) + self.check_param_equal( + obj, target_elt, "pam_groupdn", xml_field="ldap_pam_groupdn", default=None + ) + self.check_param_bool( + obj, + target_elt, + "ldap_allow_unauthenticated", + xml_field="ldap_allow_unauthenticated", + default=True, + ) + self.check_param_equal( + obj, target_elt, "timeout", xml_field="ldap_timeout", default=25 + ) ############## # tests # def test_authserver_create(self): - """ test creation of a new authserver """ - obj = dict(name='authserver1', host='ldap.example.com', transport='tcp', scope='one', authcn='CN=Users') - self.do_module_test(obj, command="create authserver_ldap 'authserver1', host='ldap.example.com'") + """test creation of a new authserver""" + obj = dict( + name="authserver1", + host="ldap.example.com", + transport="tcp", + scope="one", + authcn="CN=Users", + ) + self.do_module_test( + obj, command="create authserver_ldap 'authserver1', host='ldap.example.com'" + ) def test_authserver_delete(self): - """ test deletion of a authserver """ - obj = dict(name='DELLDAP') - self.do_module_test(obj, command="delete authserver_ldap 'DELLDAP'", delete=True) + """test deletion of a authserver""" + obj = dict(name="DELLDAP") + self.do_module_test( + obj, command="delete authserver_ldap 'DELLDAP'", delete=True + ) def test_authserver_update_noop(self): - """ test not updating a authserver """ - obj = dict(name='DELLDAP', host='ldap.example.com', transport='tcp', scope='one', authcn='CN=Users', timeout=25) - self.do_module_test(obj, command="delete authserver_ldap 'DELLDAP'", changed=False) + """test not updating a authserver""" + obj = dict( + name="DELLDAP", + host="ldap.example.com", + transport="tcp", + scope="one", + authcn="CN=Users", + timeout=25, + ) + self.do_module_test( + obj, command="delete authserver_ldap 'DELLDAP'", changed=False + ) def test_authserver_update_host(self): - """ test updating host of a authserver """ - obj = dict(name='DELLDAP', ldap_timeout=5, host='ldap2.blah.com', transport='tcp', scope='one', authcn='CN=Users') - self.do_module_test(obj, command="update authserver_ldap 'DELLDAP' set host='ldap2.blah.com'") + """test updating host of a authserver""" + obj = dict( + name="DELLDAP", + ldap_timeout=5, + host="ldap2.blah.com", + transport="tcp", + scope="one", + authcn="CN=Users", + ) + self.do_module_test( + obj, command="update authserver_ldap 'DELLDAP' set host='ldap2.blah.com'" + ) def test_authserver_disable_allow_unauthenticated(self): - """ test disabling ldap_allow_unauthenticated """ - obj = dict(name='DELLDAP', host='ldap.example.com', transport='tcp', scope='one', authcn='CN=Users', ldap_allow_unauthenticated=False) - self.do_module_test(obj, command="update authserver_ldap 'DELLDAP' set ldap_allow_unauthenticated=False") + """test disabling ldap_allow_unauthenticated""" + obj = dict( + name="DELLDAP", + host="ldap.example.com", + transport="tcp", + scope="one", + authcn="CN=Users", + ldap_allow_unauthenticated=False, + ) + self.do_module_test( + obj, + command="update authserver_ldap 'DELLDAP' set ldap_allow_unauthenticated=False", + ) ############## # misc # def test_create_authserver_invalid_timeout(self): - """ test creation of a new authserver with invalid timeout """ - obj = dict(name='DELLDAP', host='ldap.example.com', transport='tcp', scope='one', authcn='CN=Users', timeout=0) - self.do_module_test(obj, command="update authserver_ldap 'DELLDAP'", failed=True, msg='timeout 0 must be greater than 1') + """test creation of a new authserver with invalid timeout""" + obj = dict( + name="DELLDAP", + host="ldap.example.com", + transport="tcp", + scope="one", + authcn="CN=Users", + timeout=0, + ) + self.do_module_test( + obj, + command="update authserver_ldap 'DELLDAP'", + failed=True, + msg="timeout 0 must be greater than 1", + ) def test_delete_inexistent_authserver(self): - """ test deletion of an inexistent authserver """ - obj = dict(name='noauthserver') - self.do_module_test(obj, state='absent', changed=False) + """test deletion of an inexistent authserver""" + obj = dict(name="noauthserver") + self.do_module_test(obj, state="absent", changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_authserver_radius.py b/tests/unit/plugins/modules/test_pfsense_authserver_radius.py index cc4d879f..3ee14ce6 100644 --- a/tests/unit/plugins/modules/test_pfsense_authserver_radius.py +++ b/tests/unit/plugins/modules/test_pfsense_authserver_radius.py @@ -1,7 +1,8 @@ # Copyright: (c) 2022, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -10,82 +11,122 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.pfsensible.core.plugins.modules import pfsense_authserver_radius +from ansible_collections.pfsensible.core.plugins.modules import ( + pfsense_authserver_radius, +) from .pfsense_module import TestPFSenseModule class TestPFSenseAuthserverRADIUSModule(TestPFSenseModule): - module = pfsense_authserver_radius def __init__(self, *args, **kwargs): super(TestPFSenseAuthserverRADIUSModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_authserver_config.xml' + self.config_file = "pfsense_authserver_config.xml" self.pfmodule = pfsense_authserver_radius.PFSenseAuthserverRADIUSModule @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ - root_elt = self.assert_find_xml_elt(self.xml_result, 'system') - result = root_elt.findall("authserver[name='{0}']".format(obj['name'])) + """return target elt from XML""" + root_elt = self.assert_find_xml_elt(self.xml_result, "system") + result = root_elt.findall("authserver[name='{0}']".format(obj["name"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple authservers for name {0}.'.format(obj['name'])) + self.fail("Found multiple authservers for name {0}.".format(obj["name"])) else: return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ - - urltype = dict({'tcp': 'Standard TCP', 'starttls': 'STARTTLS Encrypted', 'ssl': 'SSL/TLS Encrypted'}) - self.check_param_equal(obj, target_elt, 'name') - self.assert_xml_elt_match(target_elt, 'refid', r'[0-9a-f]{13}') - self.assert_xml_elt_equal(target_elt, 'type', 'radius') - self.check_param_equal(obj, target_elt, 'host') - self.check_param_equal(obj, target_elt, 'auth_port', xml_field='radius_auth_port', default=1812) - self.check_param_equal(obj, target_elt, 'acct_port', xml_field='radius_acct_port', default=1813) - self.check_param_equal(obj, target_elt, 'protocol', xml_field='radius_protocol', default='MSCHAPv2') - self.check_param_equal(obj, target_elt, 'secret', xml_field='radius_secret') - self.check_param_equal(obj, target_elt, 'timeout', xml_field='radius_timeout', default=5) - self.check_param_equal(obj, target_elt, 'nasip_attribute', xml_field='radius_nasip_attribute', default='lan') + """check XML definition of target elt""" + + urltype = dict( + { + "tcp": "Standard TCP", + "starttls": "STARTTLS Encrypted", + "ssl": "SSL/TLS Encrypted", + } + ) + self.check_param_equal(obj, target_elt, "name") + self.assert_xml_elt_match(target_elt, "refid", r"[0-9a-f]{13}") + self.assert_xml_elt_equal(target_elt, "type", "radius") + self.check_param_equal(obj, target_elt, "host") + self.check_param_equal( + obj, target_elt, "auth_port", xml_field="radius_auth_port", default=1812 + ) + self.check_param_equal( + obj, target_elt, "acct_port", xml_field="radius_acct_port", default=1813 + ) + self.check_param_equal( + obj, target_elt, "protocol", xml_field="radius_protocol", default="MSCHAPv2" + ) + self.check_param_equal(obj, target_elt, "secret", xml_field="radius_secret") + self.check_param_equal( + obj, target_elt, "timeout", xml_field="radius_timeout", default=5 + ) + self.check_param_equal( + obj, + target_elt, + "nasip_attribute", + xml_field="radius_nasip_attribute", + default="lan", + ) ############## # tests # def test_authserver_create(self): - """ test creation of a new authserver """ - obj = dict(name='authserver1', host='radius.example.com', secret='password1') + """test creation of a new authserver""" + obj = dict(name="authserver1", host="radius.example.com", secret="password1") self.do_module_test(obj, command="create authserver_radius 'authserver1'") def test_authserver_delete(self): - """ test deletion of a authserver """ - obj = dict(name='DELRADIUS') - self.do_module_test(obj, command="delete authserver_radius 'DELRADIUS'", delete=True) + """test deletion of a authserver""" + obj = dict(name="DELRADIUS") + self.do_module_test( + obj, command="delete authserver_radius 'DELRADIUS'", delete=True + ) def test_authserver_update_noop(self): - """ test not updating a authserver """ - obj = dict(name='DELRADIUS', host='radius.example.com', secret='password1', auth_port=1812) + """test not updating a authserver""" + obj = dict( + name="DELRADIUS", + host="radius.example.com", + secret="password1", + auth_port=1812, + ) self.do_module_test(obj, changed=False) def test_authserver_update_host(self): - """ test updating host of a authserver """ - obj = dict(name='DELRADIUS', radius_timeout=25, host='radius2.blah.com', secret='password2') + """test updating host of a authserver""" + obj = dict( + name="DELRADIUS", + radius_timeout=25, + host="radius2.blah.com", + secret="password2", + ) self.do_module_test(obj, command="update authserver_radius 'DELRADIUS' set ") ############## # misc # def test_create_authserver_invalid_timeout(self): - """ test creation of a new authserver with invalid timeout """ - obj = dict(name='DELRADIUS', host='radius.example.com', secret='password1', timeout=0) - self.do_module_test(obj, command="update authserver_radius 'DELRADIUS'", failed=True, msg='timeout 0 must be greater than 1') + """test creation of a new authserver with invalid timeout""" + obj = dict( + name="DELRADIUS", host="radius.example.com", secret="password1", timeout=0 + ) + self.do_module_test( + obj, + command="update authserver_radius 'DELRADIUS'", + failed=True, + msg="timeout 0 must be greater than 1", + ) def test_delete_inexistent_authserver(self): - """ test deletion of an inexistent authserver """ - obj = dict(name='noauthserver') - self.do_module_test(obj, state='absent', changed=False) + """test deletion of an inexistent authserver""" + obj = dict(name="noauthserver") + self.do_module_test(obj, state="absent", changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_ca.py b/tests/unit/plugins/modules/test_pfsense_ca.py index 7657ff2b..e46bbf8b 100644 --- a/tests/unit/plugins/modules/test_pfsense_ca.py +++ b/tests/unit/plugins/modules/test_pfsense_ca.py @@ -1,7 +1,8 @@ # Copyright: (c) 2022, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -12,7 +13,9 @@ from ansible_collections.pfsensible.core.plugins.modules import pfsense_ca from .pfsense_module import TestPFSenseModule -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) CERTIFICATE = ( "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVDRENDQXZDZ0F3SUJBZ0lJRmpGT2hzMW5NelF3RFFZSktvWklodmNOQVFFTEJRQXdYREVUTUJFR0ExVUUKQXhNS2IzQmxiblp3YmkxallURUxN" @@ -27,7 +30,8 @@ "VkhSTUVCVEFEQVFIL01Bc0dBMVVkRHdRRUF3SUJCakFOQmdrcWhraUcKOXcwQkFRc0ZBQU9DQVFFQVVIOUtDZG1KZG9BSmxVMHdCSkhZeGpMcktsbFBZNk9OYnpyNUpiaENNNjlIeHhZTgpCa2lpbXd1" "N09mRmFGZkZDT25NSjhvcStKVGxjMG9vREoxM2xCdHRONkdybnZrUTNQMXdZYkNFTmJuaWxPYVVCClRJcmlIeXRORFFhb3VOYS9LV3M3RmF1b2JjdEJsMXc5YXRvSFpzTjVvZWhUM3JBVHYxQ0NBdGpw" "YVRKSWZKUjMKMElRT1lrZTRvWTZEa0l3SHAydlBQbW9vR2dJdGJUdzNVK0U0MVlaZTdxQ21FLzd6TFRTWmtJTTJseDZ6RDQ2agpEZjRyZ044TVVMNnhpd09MbzlyQUp5ckRNM2JEeTJ1QjY0QkVzRFFM" - "a2huUE92ZWtETjQ1NnV6TmpYS0E3VnE4CmgxL2d6RFpJRGkrV1hDWUFjYmdMaFpWQnF0bjYydW1GcE1SSXV3PT0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=") + "a2huUE92ZWtETjQ1NnV6TmpYS0E3VnE4CmgxL2d6RFpJRGkrV1hDWUFjYmdMaFpWQnF0bjYydW1GcE1SSXV3PT0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=" +) CRL1 = ( "LS0tLS1CRUdJTiBYNTA5IENSTC0tLS0tCk1JSUNkRENDQVZ3Q0FRRXdEUVlKS29aSWh2Y05BUUVGQlFBd1hERVRNQkVHQTFVRUF4TUtiM0JsYm5ad2JpMWoKWVRFTE1Ba0dBMVVFQmhNQ1ZWTXhFVEFQ" "QmdOVkJBZ1RDRU52Ykc5eVlXUnZNUkF3RGdZRFZRUUhFd2RDYjNWcwpaR1Z5TVJNd0VRWURWUVFLRXdwd1psTmxibk5wWW14bEZ3MHlNakF5TVRrd05UVXhNRFphRncwME9UQTNNRFl3Ck5UVXhNRFph" @@ -36,7 +40,8 @@ "UVFLRXdwd1psTmxibk5wWW14bGdnZ1cKTVU2R3pXY3pOREFMQmdOVkhSUUVCQUlDSnhFd0RRWUpLb1pJaHZjTkFRRUZCUUFEZ2dFQkFGbXJ5cFUxU3p5dApNUUZCRWFZZk9waVpqRVhVajE5MVZuWENl" "b0tNMk83bVUzYW5HVXRZQUJMcG15dmN2YnU2ZkJCVEtYSTFEb0VvClJkV1VDTVMxbk5BTWwyU0N0ZmJ5RHNHNjZHczRiNnRZeXE1SW5LVFJJdldUeU5vS0JiUHc1OHZYV0ljNmVmUXgKSTYvZSt4U3di" "eE9MSFlRdGd4WTJOdk9xVGVnVE0rTHpIcmNJWmFPS09NbHNodTA4ajgzSnUxR0ttYlBKME1jZwpyVXNiYXRKcURUdWtQMi9VbmI0N1hwN21qUHVTY0Z5MjN2RGl2OHdvcjBYOEFSQW1ibTN4N2ZKeTlt" - "V2d1OVhMCmpNV1lxN1BEaXhwWElqTVdhZzN2bVYxOC9IdDIybW1xS1RPM3prVnJLUDA1TEhCNVloM2ZZcEpWdEhkeENlTzUKdmlvbU53SzA3QUE9Ci0tLS0tRU5EIFg1MDkgQ1JMLS0tLS0=") + "V2d1OVhMCmpNV1lxN1BEaXhwWElqTVdhZzN2bVYxOC9IdDIybW1xS1RPM3prVnJLUDA1TEhCNVloM2ZZcEpWdEhkeENlTzUKdmlvbU53SzA3QUE9Ci0tLS0tRU5EIFg1MDkgQ1JMLS0tLS0=" +) CRL2 = ( "-----BEGIN X509 CRL-----\n" "MIICSDCCATACAQEwDQYJKoZIhvcNAQEFBQAwXDETMBEGA1UEAxMKb3BlbnZwbi1j\n" @@ -52,101 +57,111 @@ "Ny5w9dLF4s+6qFXjvfYmQ0FyeRcltUoF3kTabS1WCdkGjsUSeGHBFLM4NH2mJPMR\n" "0yfIGdipSonSTF51ICqgoUGAYPqObvlQZDMjFF+GFL3LNQ7gO+1R1OMMKAZ+96nX\n" "gwt+00UVYhQCCZ3k\n" - "-----END X509 CRL-----\n") + "-----END X509 CRL-----\n" +) class TestPFSenseCAModule(TestPFSenseModule): - module = pfsense_ca def __init__(self, *args, **kwargs): super(TestPFSenseCAModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_ca_config.xml' + self.config_file = "pfsense_ca_config.xml" self.pfmodule = pfsense_ca.PFSenseCAModule def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseCAModule, self).setUp() - self.mock_php = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php') + self.mock_php = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php" + ) self.php = self.mock_php.start() - self.php.return_value = '12000' + self.php.return_value = "12000" @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ + """return target elt from XML""" root_elt = self.xml_result.getroot() - result = root_elt.findall("ca[descr='{0}']".format(obj['name'])) + result = root_elt.findall("ca[descr='{0}']".format(obj["name"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple CAs for name {0}.'.format(obj['name'])) + self.fail("Found multiple CAs for name {0}.".format(obj["name"])) else: return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ - - self.check_param_equal(obj, target_elt, 'name', xml_field='descr') - if 'trust' in obj: - self.check_param_bool(obj, target_elt, 'trust', value_true='enabled', value_false='disabled') - if 'randomserial' in obj: - self.check_param_bool(obj, target_elt, 'randomserial', value_true='enabled', value_false='disabled') - self.check_param_equal_or_present(obj, target_elt, 'serial') - self.check_param_equal(obj, target_elt, 'certificate', xml_field='crt') + """check XML definition of target elt""" + + self.check_param_equal(obj, target_elt, "name", xml_field="descr") + if "trust" in obj: + self.check_param_bool( + obj, target_elt, "trust", value_true="enabled", value_false="disabled" + ) + if "randomserial" in obj: + self.check_param_bool( + obj, + target_elt, + "randomserial", + value_true="enabled", + value_false="disabled", + ) + self.check_param_equal_or_present(obj, target_elt, "serial") + self.check_param_equal(obj, target_elt, "certificate", xml_field="crt") ############## # tests # def test_ca_create(self): - """ test creation of a new ca """ - obj = dict(name='ca1', certificate=CERTIFICATE) + """test creation of a new ca""" + obj = dict(name="ca1", certificate=CERTIFICATE) self.do_module_test(obj, command="create ca 'ca1'") def test_ca_add_crl(self): - """ test adding a CRL """ - obj = dict(name='ca1', certificate=CERTIFICATE, crl=CRL1) + """test adding a CRL""" + obj = dict(name="ca1", certificate=CERTIFICATE, crl=CRL1) self.do_module_test(obj, command="create ca 'ca1'") def test_ca_change_crl(self): - """ test adding a CRL """ - obj = dict(name='ca1', certificate=CERTIFICATE, crl=CRL2) + """test adding a CRL""" + obj = dict(name="ca1", certificate=CERTIFICATE, crl=CRL2) self.do_module_test(obj, command="create ca 'ca1'") def test_ca_delete(self): - """ test deletion of a ca """ - obj = dict(name='testdel') + """test deletion of a ca""" + obj = dict(name="testdel") self.do_module_test(obj, command="delete ca 'testdel'", delete=True) def test_ca_update_noop(self): - """ test not updating a ca """ - obj = dict(name='testdel', certificate=CERTIFICATE) + """test not updating a ca""" + obj = dict(name="testdel", certificate=CERTIFICATE) self.do_module_test(obj, changed=False) def test_ca_update_serial(self): - """ test updating serial of a ca """ - obj = dict(name='testdel', certificate=CERTIFICATE, serial=10) + """test updating serial of a ca""" + obj = dict(name="testdel", certificate=CERTIFICATE, serial=10) self.do_module_test(obj, command="update ca 'testdel' set serial='10'") def test_ca_update_trust(self): - """ test updating trust of a ca """ - obj = dict(name='testdel', certificate=CERTIFICATE, trust=False) + """test updating trust of a ca""" + obj = dict(name="testdel", certificate=CERTIFICATE, trust=False) self.do_module_test(obj, command="update ca 'testdel' set ") ############## # misc # def test_create_ca_invalid_serial(self): - """ test creation of a new ca with invalid serial """ - obj = dict(name='ca1', certificate=CERTIFICATE, serial=-1) - self.do_module_test(obj, failed=True, msg='serial must be greater than 0') + """test creation of a new ca with invalid serial""" + obj = dict(name="ca1", certificate=CERTIFICATE, serial=-1) + self.do_module_test(obj, failed=True, msg="serial must be greater than 0") def test_delete_nonexistent_ca(self): - """ test deletion of an nonexistent ca """ - obj = dict(name='noca') - self.do_module_test(obj, commmand=None, state='absent', changed=False) + """test deletion of an nonexistent ca""" + obj = dict(name="noca") + self.do_module_test(obj, commmand=None, state="absent", changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_cert.py b/tests/unit/plugins/modules/test_pfsense_cert.py index fc013367..6bb9a7fc 100644 --- a/tests/unit/plugins/modules/test_pfsense_cert.py +++ b/tests/unit/plugins/modules/test_pfsense_cert.py @@ -1,7 +1,8 @@ # Copyright: (c) 2025, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -26,7 +27,8 @@ "VkhSTUVCVEFEQVFIL01Bc0dBMVVkRHdRRUF3SUJCakFOQmdrcWhraUcKOXcwQkFRc0ZBQU9DQVFFQVVIOUtDZG1KZG9BSmxVMHdCSkhZeGpMcktsbFBZNk9OYnpyNUpiaENNNjlIeHhZTgpCa2lpbXd1" "N09mRmFGZkZDT25NSjhvcStKVGxjMG9vREoxM2xCdHRONkdybnZrUTNQMXdZYkNFTmJuaWxPYVVCClRJcmlIeXRORFFhb3VOYS9LV3M3RmF1b2JjdEJsMXc5YXRvSFpzTjVvZWhUM3JBVHYxQ0NBdGpw" "YVRKSWZKUjMKMElRT1lrZTRvWTZEa0l3SHAydlBQbW9vR2dJdGJUdzNVK0U0MVlaZTdxQ21FLzd6TFRTWmtJTTJseDZ6RDQ2agpEZjRyZ044TVVMNnhpd09MbzlyQUp5ckRNM2JEeTJ1QjY0QkVzRFFM" - "a2huUE92ZWtETjQ1NnV6TmpYS0E3VnE4CmgxL2d6RFpJRGkrV1hDWUFjYmdMaFpWQnF0bjYydW1GcE1SSXV3PT0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=") + "a2huUE92ZWtETjQ1NnV6TmpYS0E3VnE4CmgxL2d6RFpJRGkrV1hDWUFjYmdMaFpWQnF0bjYydW1GcE1SSXV3PT0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=" +) TESTDEL_KEY = ( "LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tDQpNSUlFdlFJQkFEQU5CZ2txaGtpRzl3MEJBUUVGQUFTQ0JLY3dnZ1NqQWdFQUFvSUJBUUNheStJa3dUVVJONTNoDQo4NjF1UWVrWGQzMTR6N3JqVHhj" "K2RyeTBrWjBHcmRRQUxFVE9ob0VwRzNiTFZ5M3JOQU43a1VYZ0FDKzFYVzV4DQptbXlYbWpnTkw4Z1pKdWZwM2RneVQ2UHlNQkRjU2JOSHdZNmlHZUIvVkQwQmNMcWlIaXNFL214ZmhqeTVqeVFuDQpk" @@ -43,7 +45,8 @@ "d09Gbzh2cmZJSXNwRTNnamh5MTZCbGsyUUFRMGJwbGpBVFljVDNDWWhUZEU1DQpFNkZwRzVNNllCTXJ6YUxwc1JDVzFtZjJnLzYzelhNMzJUVXJFdFJyRVdGUE84TUI0blF0Y1Y5a2pFa3hGNHRWDQpD" "dGp3YjI5MEtNUFNDS00wS08vSTVDUXdpTFAydUtkeTBSRkpnRHUxQW9HQUxqSTdGZDl0ckFrY1dSenMyVHpRDQptaDRTZWxHRDFvdFAxRXpFa1hYVlZwK1lMNXlxOWM3V3hsa09RY1lFVmd1N2huNVFn" "cnBRUFZPTFVmbkJoajdXDQo1MC9IVmR6V21wSXg5NUlqZXFDZklBV1U3N0I5cmJUR2hvWWMzbTdJcEdObzl2WlhHYWgrc2JHY3BEK1phV3UzDQp1Q25pTnJpZEhORGgzWHZQVFZkRTRlVT0NCi0tLS0t" - "RU5EIFBSSVZBVEUgS0VZLS0tLS0NCg==") + "RU5EIFBSSVZBVEUgS0VZLS0tLS0NCg==" +) WEB_CRT = ( "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVsRENDQTN5Z0F3SUJBZ0lJU2JlZ00zSWZSTEV3RFFZSktvWklodmNOQVFFTEJRQXdXakU0TURZR0ExVUUKQ2hNdmNHWlRaVzV6WlNCM1pXSkRi" "MjVtYVdkMWNtRjBiM0lnVTJWc1ppMVRhV2R1WldRZ1EyVnlkR2xtYVdOaApkR1V4SGpBY0JnTlZCQU1URlhCbVUyVnVjMlV0TmpJd09ETTJOemxqWkROa05EQWVGdzB5TWpBeU1USXlNak0yCk5ERmFG" @@ -59,7 +62,8 @@ "U0liM0RRRUIKQ3dVQUE0SUJBUUFzazBrTU12dVR0T3c2Ymx5a1U5cWNkRnQvVDlGOFZBZ0taNHgzYXNxNlArRG96N1FGVFpKVwprdmlrQVVUekpMMys4c0NKRDdjV3BZa2ZpdDRBYndhWFIyRzVsczhj" "L0JRcUdmY1ZOUnJVdWRscG12UUYrYk5iClMxZ2xjS2hYYXZuYnlQdkRMem9CZGVlTmhqYXIzcWc1TTV6T3I0aXYyM0hCZVc2aEY2c0FrV3dpVkU5NEJmZ00KOS9qeW5GalVYTkJheStMODM2TXBpNDhp" "NnE4OHdlQ25UdDdaTFFjWlZXb0IwcWNQSS96SExTUFlTNlhhcmdvdgpva3E1M3ZQSG9HNnRGUHpFSkpFVmNmOTV1bVcwaUpFR3hCQ3dTeVlnd2xSY0pEeGJ1QklFY0xWb2JKclVveHNLClJXcW13SHdQ" - "YkFxRjBOMUZ0cFJ6K3Yvd0lQYWdSQ2lVCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K") + "YkFxRjBOMUZ0cFJ6K3Yvd0lQYWdSQ2lVCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K" +) WEB_KEY = ( "LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JSUV2d0lCQURBTkJna3Foa2lHOXcwQkFRRUZBQVNDQktrd2dnU2xBZ0VBQW9JQkFRREc5M1VuTzJtTG00RVUKNHA5S2UwK2ZuRkJBNkZXbjl5MENy" "FpjR3FUVmI3QTd6NjhjUkF3SXhsRUNia1ZpTHV6N0pzVCtoMHdSYnJaNwpFYjV2Sm9nclpQdlJXUHJ5b0lBbCtPbDVvV0VsTTJxMFAxUlB2NkROd1piMWxQeDM1d1c4QWtYMUJISy9xRFliCmJuNVVFd" @@ -75,77 +79,86 @@ "TJRblR3Q1UKUk8xTkR5MUx3cEVQQjlvWkE4SkoydCtudTlXTWdmV2dxODJZZFhlSWxxT2JyejZKTitOTjB2R0ZEdTBDZ1lCcQp5a1p3anNVV2ZCdEVhZFVyanQ4aTJxNGYzbTBxY3kzY0pjRzVGbGV6T" "3JUaEpjN0ZRdndSZHhYQ213YUhBMDNVCktxQmV3YnhYSWo5TWcwWk8yMG4wbEdyYVdMVDNKTndBbmtrQXZ5VjVoSWlPTVBNMm8wN1JPNjVJTzdKallKNnIKcUdVVnZnenRBdzVJSXdST29Edjc2UHdxT" "HJpS3VLVmY4c0wrcDRMTlhRS0JnUUNrbXlIR3dielVJLzN6Z1FNSApxY1RkMUttMDhxTXJjUzZXdU03RGJxTEpIQXdlcFRSYUpuVmVLcnRRS2t4SGRlcjZsK2VWUjNvWXArUE9EbTVECjBGSXMxbXd2R" - "Fp3TjI2RVhmZnZkWG9EK1luNnVEeGlvVU9QZ1A2U1hLT3dxcnBuenVFZTFBNFpDQTFnRXJhd2MKTTNmejdiV3d0a1JUUE5uaGxybkVJNXY5Qmc9PQotLS0tLUVORCBQUklWQVRFIEtFWS0tLS0tCg==") + "Fp3TjI2RVhmZnZkWG9EK1luNnVEeGlvVU9QZ1A2U1hLT3dxcnBuenVFZTFBNFpDQTFnRXJhd2MKTTNmejdiV3d0a1JUUE5uaGxybkVJNXY5Qmc9PQotLS0tLUVORCBQUklWQVRFIEtFWS0tLS0tCg==" +) RSA_KEY = ( "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlCUFFJQkFBSkJBTmNObWdtNFlsU1VBcjJ4ZFdlaTVhUlUvRGJXdHNRNDdnamt2MjhFa2plM29iKzZxME0rCkQ1cGh3WURjdjl5Z1ltdUo1" "d09pMWNQcHJzV2RGV212U3VzQ0F3RUFBUUpCQUtRQ3paM29MNllOaytHVU85UWsKV2p0d1RVS05rcW9vT1BJemN3UjZXZ0YrOXEydlNIaTQxLzZmdjFOaDJQZU91ZDcvZHFxTklLbGxGZXdIYnJsbApp" "dUVDSVFENHZvZUZqSEdMMzllcGVXVlRpYnF6UWdQTFYzWmlmbHYzMEdkb3ZqTGhVd0loQU4xVGZNOFNxdlBiCjJtelVzL2pITDJQMjl1U1B1bHd3b3lOQ052dFk3a1VKQWlFQTI5YUFMYzYzRjVrSW9G" "YVM3K2JjNDhyblVaS0cKSlh4cHliWWRmcHdDbWRNQ0lRREhGbnFHcW53c3IrOWpSbEk5enE3S2RUVFJsSmhHcFZtYU5jM1Blc2VhUVFJaApBTzl6UUUralBYK2pXbGhpTWMzZnM5amNiVWJKMWpTUDYv" - "aDBXd3Iyb1dJRwotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==") + "aDBXd3Iyb1dJRwotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQ==" +) class TestPFSenseCertModule(TestPFSenseModule): - module = pfsense_cert def __init__(self, *args, **kwargs): super(TestPFSenseCertModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_cert_config.xml' + self.config_file = "pfsense_cert_config.xml" self.pfmodule = pfsense_cert.PFSenseCertModule @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ + """return target elt from XML""" root_elt = self.xml_result.getroot() - result = root_elt.findall("cert[descr='{0}']".format(obj['name'])) + result = root_elt.findall("cert[descr='{0}']".format(obj["name"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple certs for name {0}.'.format(obj['name'])) + self.fail("Found multiple certs for name {0}.".format(obj["name"])) else: return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ + """check XML definition of target elt""" - self.check_param_equal(obj, target_elt, 'name', xml_field='descr') - self.check_param_equal(obj, target_elt, 'certificate', xml_field='crt') - if 'key' in obj: - self.check_param_equal_or_present(obj, target_elt, 'prv') + self.check_param_equal(obj, target_elt, "name", xml_field="descr") + self.check_param_equal(obj, target_elt, "certificate", xml_field="crt") + if "key" in obj: + self.check_param_equal_or_present(obj, target_elt, "prv") ############## # tests # def test_cert_create(self): - """ test creation of a new cert """ - obj = dict(name='cert1', ca='testdel') + """test creation of a new cert""" + obj = dict(name="cert1", ca="testdel") self.do_module_test(obj, command="create cert 'cert1'") def test_cert_import(self): - """ test import of a new cert """ - obj = dict(name='cert1', method='import', certificate=WEB_CRT, key=RSA_KEY) + """test import of a new cert""" + obj = dict(name="cert1", method="import", certificate=WEB_CRT, key=RSA_KEY) self.do_module_test(obj, command="create cert 'cert1'") def test_cert_delete(self): - """ test deletion of a cert """ - obj = dict(name='webConfigurator default (62083679cd3d4)') - self.do_module_test(obj, command="delete cert 'webConfigurator default (62083679cd3d4)'", delete=True) + """test deletion of a cert""" + obj = dict(name="webConfigurator default (62083679cd3d4)") + self.do_module_test( + obj, + command="delete cert 'webConfigurator default (62083679cd3d4)'", + delete=True, + ) def test_cert_update_noop(self): - """ test not updating a cert """ - obj = dict(name='webConfigurator default (62083679cd3d4)', method='import', certificate=WEB_CRT) + """test not updating a cert""" + obj = dict( + name="webConfigurator default (62083679cd3d4)", + method="import", + certificate=WEB_CRT, + ) self.do_module_test(obj, changed=False) ############## # misc # def test_add_invalid_key(self): - """ test adding an invalid key """ - key = 'blah' - obj = dict(name='invalid', method='import', certificate=WEB_CRT, key=key) - msg = 'Could not recognize key format: %s' % (key) + """test adding an invalid key""" + key = "blah" + obj = dict(name="invalid", method="import", certificate=WEB_CRT, key=key) + msg = "Could not recognize key format: %s" % (key) self.do_module_test(obj, failed=True, msg=msg) diff --git a/tests/unit/plugins/modules/test_pfsense_dhcp_server.py b/tests/unit/plugins/modules/test_pfsense_dhcp_server.py index 698c86c9..ebc2d5c0 100644 --- a/tests/unit/plugins/modules/test_pfsense_dhcp_server.py +++ b/tests/unit/plugins/modules/test_pfsense_dhcp_server.py @@ -1,7 +1,8 @@ # Copyright: (c) 2024, David Rosado # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,144 +12,173 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_dhcp_server -from ansible_collections.pfsensible.core.plugins.modules.pfsense_dhcp_server import PFSenseDHCPServerModule +from ansible_collections.pfsensible.core.plugins.modules.pfsense_dhcp_server import ( + PFSenseDHCPServerModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseDHCPServerModule(TestPFSenseModule): - module = pfsense_dhcp_server def __init__(self, *args, **kwargs): super(TestPFSenseDHCPServerModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_dhcp_server_config.xml' + self.config_file = "pfsense_dhcp_server_config.xml" self.pfmodule = PFSenseDHCPServerModule def check_target_elt(self, obj, target_elt, target_idx=-1): - """ test the xml definition """ + """test the xml definition""" # self.check_param_equal(obj, target_elt, 'interface') - self.check_param_bool(obj, target_elt, 'enable') - self.check_param_equal(obj, target_elt, 'range_from', xml_field='range/from') - self.check_param_equal(obj, target_elt, 'range_to', xml_field='range/to') - self.check_param_equal(obj, target_elt, 'failover_peerip') - self.check_param_equal(obj, target_elt, 'defaultleasetime') - self.check_param_equal(obj, target_elt, 'maxleasetime') - self.check_param_equal(obj, target_elt, 'netmask') - self.check_param_equal(obj, target_elt, 'gateway') - self.check_param_equal(obj, target_elt, 'domain') - self.check_param_equal(obj, target_elt, 'domainsearchlist') - self.check_param_equal(obj, target_elt, 'ddnsdomain') - self.check_param_equal(obj, target_elt, 'ddnsdomainprimary') - self.check_param_equal(obj, target_elt, 'ddnsdomainkeyname') - self.check_param_equal(obj, target_elt, 'ddnsdomainkeyalgorithm', default='hmac-md5') - self.check_param_equal(obj, target_elt, 'ddnsdomainkey') - self.check_param_equal(obj, target_elt, 'mac_allow') - self.check_param_equal(obj, target_elt, 'mac_deny') - self.check_param_equal(obj, target_elt, 'tftp') - self.check_param_equal(obj, target_elt, 'ldap') - self.check_param_equal(obj, target_elt, 'nextserver') - self.check_param_equal(obj, target_elt, 'filename') - self.check_param_equal(obj, target_elt, 'filename32') - self.check_param_equal(obj, target_elt, 'filename64') - self.check_param_equal(obj, target_elt, 'rootpath') - self.check_param_equal(obj, target_elt, 'numberoptions') + self.check_param_bool(obj, target_elt, "enable") + self.check_param_equal(obj, target_elt, "range_from", xml_field="range/from") + self.check_param_equal(obj, target_elt, "range_to", xml_field="range/to") + self.check_param_equal(obj, target_elt, "failover_peerip") + self.check_param_equal(obj, target_elt, "defaultleasetime") + self.check_param_equal(obj, target_elt, "maxleasetime") + self.check_param_equal(obj, target_elt, "netmask") + self.check_param_equal(obj, target_elt, "gateway") + self.check_param_equal(obj, target_elt, "domain") + self.check_param_equal(obj, target_elt, "domainsearchlist") + self.check_param_equal(obj, target_elt, "ddnsdomain") + self.check_param_equal(obj, target_elt, "ddnsdomainprimary") + self.check_param_equal(obj, target_elt, "ddnsdomainkeyname") + self.check_param_equal( + obj, target_elt, "ddnsdomainkeyalgorithm", default="hmac-md5" + ) + self.check_param_equal(obj, target_elt, "ddnsdomainkey") + self.check_param_equal(obj, target_elt, "mac_allow") + self.check_param_equal(obj, target_elt, "mac_deny") + self.check_param_equal(obj, target_elt, "tftp") + self.check_param_equal(obj, target_elt, "ldap") + self.check_param_equal(obj, target_elt, "nextserver") + self.check_param_equal(obj, target_elt, "filename") + self.check_param_equal(obj, target_elt, "filename32") + self.check_param_equal(obj, target_elt, "filename64") + self.check_param_equal(obj, target_elt, "rootpath") + self.check_param_equal(obj, target_elt, "numberoptions") def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - root_elt = self.assert_find_xml_elt(self.xml_result, 'dhcpd') - return root_elt.find(obj['interface']) + """get the generated xml definition""" + root_elt = self.assert_find_xml_elt(self.xml_result, "dhcpd") + return root_elt.find(obj["interface"]) ############## # tests # def test_dhcp_server_create(self): - """ test creation of a new DHCP server """ + """test creation of a new DHCP server""" obj = dict( - interface='opt2', + interface="opt2", enable=True, - range_from='172.16.0.100', - range_to='172.16.0.199', + range_from="172.16.0.100", + range_to="172.16.0.199", defaultleasetime=86400, maxleasetime=172800, - domain='opt2.example.com' + domain="opt2.example.com", ) - command_as_list = ["create dhcp_server 'opt2', enable=True, range_from='172.16.0.100', ", - "range_to='172.16.0.199', failover_peerip='', defaultleasetime='86400', ", - "maxleasetime='172800', netmask='', gateway='', domain='opt2.example.com', ", - "domainsearchlist='', ddnsdomain='', ddnsdomainprimary='', ddnsdomainkeyname='', ", - "ddnsdomainkeyalgorithm='hmac-md5', ddnsdomainkey='', mac_allow='', mac_deny='', ", - "ddnsclientupdates='allow', tftp='', ldap='', nextserver='', filename='', filename32='', ", - "filename64='', rootpath='', numberoptions=''"] + command_as_list = [ + "create dhcp_server 'opt2', enable=True, range_from='172.16.0.100', ", + "range_to='172.16.0.199', failover_peerip='', defaultleasetime='86400', ", + "maxleasetime='172800', netmask='', gateway='', domain='opt2.example.com', ", + "domainsearchlist='', ddnsdomain='', ddnsdomainprimary='', ddnsdomainkeyname='', ", + "ddnsdomainkeyalgorithm='hmac-md5', ddnsdomainkey='', mac_allow='', mac_deny='', ", + "ddnsclientupdates='allow', tftp='', ldap='', nextserver='', filename='', filename32='', ", + "filename64='', rootpath='', numberoptions=''", + ] command = "".join(command_as_list) self.do_module_test(obj, command=command) def test_dhcp_server_update(self): - """ test updating an existing DHCP server """ + """test updating an existing DHCP server""" obj = dict( - interface='lan', + interface="lan", enable=True, - range_from='192.168.1.50', - range_to='192.168.1.150', - domain='updated.example.com' + range_from="192.168.1.50", + range_to="192.168.1.150", + domain="updated.example.com", ) - command_as_list = ["update dhcp_server 'lan' set , range_from='192.168.1.50', range_to='192.168.1.150', ", - "defaultleasetime='', maxleasetime='', domain='updated.example.com'"] + command_as_list = [ + "update dhcp_server 'lan' set , range_from='192.168.1.50', range_to='192.168.1.150', ", + "defaultleasetime='', maxleasetime='', domain='updated.example.com'", + ] command = "".join(command_as_list) self.do_module_test(obj, command=command) def test_dhcp_server_update_disable_denyunknown(self): - """ test disabling denyunknown from an existing DHCP server """ + """test disabling denyunknown from an existing DHCP server""" obj = dict( - interface='opt1', + interface="opt1", enable=True, - range_from='10.0.0.100', - range_to='10.0.0.199', - denyunknown='disabled', + range_from="10.0.0.100", + range_to="10.0.0.199", + denyunknown="disabled", ) - command_as_list = ["update dhcp_server 'opt1' set , ", - "defaultleasetime='', maxleasetime='', domain='', denyunknown=none"] + command_as_list = [ + "update dhcp_server 'opt1' set , ", + "defaultleasetime='', maxleasetime='', domain='', denyunknown=none", + ] command = "".join(command_as_list) self.do_module_test(obj, command=command) def test_dhcp_server_delete(self): - """ test deletion of a DHCP server """ - obj = dict(interface='opt1', state='absent') + """test deletion of a DHCP server""" + obj = dict(interface="opt1", state="absent") command = "delete dhcp_server 'opt1'" self.do_module_test(obj, command=command, delete=True) def test_dhcp_server_create_invalid_interface(self): - """ test creation with an invalid interface """ - obj = dict(interface='invalid_interface', enable=True, range_from='192.168.1.100', range_to='192.168.1.200') - self.do_module_test(obj, failed=True, msg='The specified interface invalid_interface is not a valid logical interface or cannot be mapped to one') + """test creation with an invalid interface""" + obj = dict( + interface="invalid_interface", + enable=True, + range_from="192.168.1.100", + range_to="192.168.1.200", + ) + self.do_module_test( + obj, + failed=True, + msg="The specified interface invalid_interface is not a valid logical interface or cannot be mapped to one", + ) def test_dhcp_server_create_invalid_range(self): - """ test creation with an invalid IP range """ - interface = 'lan' - obj = dict(interface=interface, enable=True, range_from='192.168.1.200', range_to='192.168.1.100') - self.do_module_test(obj, failed=True, msg=f'The interface {interface} must have a valid IP range pool') + """test creation with an invalid IP range""" + interface = "lan" + obj = dict( + interface=interface, + enable=True, + range_from="192.168.1.200", + range_to="192.168.1.100", + ) + self.do_module_test( + obj, + failed=True, + msg=f"The interface {interface} must have a valid IP range pool", + ) def test_dhcp_server_create_with_options(self): - """ test creation with additional DHCP options """ + """test creation with additional DHCP options""" obj = dict( - interface='opt2', + interface="opt2", enable=True, - range_from='172.16.0.50', - range_to='172.16.0.150', + range_from="172.16.0.50", + range_to="172.16.0.150", defaultleasetime=43200, maxleasetime=86400, - domain='opt1.example.com', - ddnsdomain='ddns.example.com', - ddnsdomainprimary='172.16.0.60', - tftp='172.16.0.63', + domain="opt1.example.com", + ddnsdomain="ddns.example.com", + ddnsdomainprimary="172.16.0.60", + tftp="172.16.0.63", disablepingcheck=True, - winsserver=['172.16.0.80', '172.16.0.90'] + winsserver=["172.16.0.80", "172.16.0.90"], ) - command_as_list = ["create dhcp_server 'opt2', enable=True, range_from='172.16.0.50', ", - "range_to='172.16.0.150', failover_peerip='', defaultleasetime='43200', ", - "maxleasetime='86400', netmask='', gateway='', domain='opt1.example.com', ", - "domainsearchlist='', ddnsdomain='ddns.example.com', ddnsdomainprimary='172.16.0.60', ", - "ddnsdomainkeyname='', ddnsdomainkeyalgorithm='hmac-md5', ddnsdomainkey='', ", - "mac_allow='', mac_deny='', ddnsclientupdates='allow', tftp='172.16.0.63', ldap='', ", - "nextserver='', filename='', filename32='', filename64='', rootpath='', numberoptions=''"] + command_as_list = [ + "create dhcp_server 'opt2', enable=True, range_from='172.16.0.50', ", + "range_to='172.16.0.150', failover_peerip='', defaultleasetime='43200', ", + "maxleasetime='86400', netmask='', gateway='', domain='opt1.example.com', ", + "domainsearchlist='', ddnsdomain='ddns.example.com', ddnsdomainprimary='172.16.0.60', ", + "ddnsdomainkeyname='', ddnsdomainkeyalgorithm='hmac-md5', ddnsdomainkey='', ", + "mac_allow='', mac_deny='', ddnsclientupdates='allow', tftp='172.16.0.63', ldap='', ", + "nextserver='', filename='', filename32='', filename64='', rootpath='', numberoptions=''", + ] command = "".join(command_as_list) self.do_module_test(obj, command=command) diff --git a/tests/unit/plugins/modules/test_pfsense_dhcp_static.py b/tests/unit/plugins/modules/test_pfsense_dhcp_static.py index d433b7a3..31b4d99e 100644 --- a/tests/unit/plugins/modules/test_pfsense_dhcp_static.py +++ b/tests/unit/plugins/modules/test_pfsense_dhcp_static.py @@ -1,7 +1,8 @@ # Copyright: (c) 2023 Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,62 +12,91 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_dhcp_static -from ansible_collections.pfsensible.core.plugins.modules.pfsense_dhcp_static import PFSenseDHCPStaticModule +from ansible_collections.pfsensible.core.plugins.modules.pfsense_dhcp_static import ( + PFSenseDHCPStaticModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseDHCPStaticModule(TestPFSenseModule): - module = pfsense_dhcp_static def __init__(self, *args, **kwargs): super(TestPFSenseDHCPStaticModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_dhcp_static_config.xml' + self.config_file = "pfsense_dhcp_static_config.xml" self.pfmodule = PFSenseDHCPStaticModule def check_target_elt(self, obj, target_elt, target_idx=-1): - """ test the xml definition """ + """test the xml definition""" # checking destination address and ports - self.check_param_equal(obj, target_elt, 'name', xml_field='cid') - self.check_param_equal(obj, target_elt, 'macaddr', xml_field='mac') + self.check_param_equal(obj, target_elt, "name", xml_field="cid") + self.check_param_equal(obj, target_elt, "macaddr", xml_field="mac") # Forced options - for option in ['ipaddr', 'hostname', 'descr', 'filename', - 'rootpath', 'defaultleasetime', 'maxleasetime', - 'gateway', 'domain', 'domainsearchlist', - 'ddnsdomain', 'ddnsdomainprimary', 'ddnsdomainsecondary', - 'ddnsdomainkeyname', 'ddnsdomainkeyalgorithm', 'ddnsdomainkey', - 'tftp', 'ldap', 'nextserver', 'filename32', 'filename64', - 'filename32arm', 'filename64arm', 'uefihttpboot', 'numberoptions']: + for option in [ + "ipaddr", + "hostname", + "descr", + "filename", + "rootpath", + "defaultleasetime", + "maxleasetime", + "gateway", + "domain", + "domainsearchlist", + "ddnsdomain", + "ddnsdomainprimary", + "ddnsdomainsecondary", + "ddnsdomainkeyname", + "ddnsdomainkeyalgorithm", + "ddnsdomainkey", + "tftp", + "ldap", + "nextserver", + "filename32", + "filename64", + "filename32arm", + "filename64arm", + "uefihttpboot", + "numberoptions", + ]: self.check_param_equal_or_present(obj, target_elt, option) # Non-forced options - for option in ['winsserver', 'dnsserver', 'ntpserver']: + for option in ["winsserver", "dnsserver", "ntpserver"]: self.check_param_equal(obj, target_elt, option) # Defaulted options - self.check_param_equal(obj, target_elt, 'ddnsdomainkeyalgorithm', default='hmac-md5') + self.check_param_equal( + obj, target_elt, "ddnsdomainkeyalgorithm", default="hmac-md5" + ) def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - dhcpd_elt = self.assert_find_xml_elt(self.xml_result, 'dhcpd') + """get the generated xml definition""" + dhcpd_elt = self.assert_find_xml_elt(self.xml_result, "dhcpd") root_elt = None for e in dhcpd_elt: - if 'netif' not in obj or (module_result is not None and e.tag == module_result['netif']): - if e.find('enable') is not None: + if "netif" not in obj or ( + module_result is not None and e.tag == module_result["netif"] + ): + if e.find("enable") is not None: root_elt = e break result = [] if root_elt is not None: - if 'name' in obj and 'macaddr' in obj: - result = root_elt.findall("staticmap[cid='{0}'][mac='{1}']".format(obj['name'], obj['macaddr'])) - elif 'name' in obj: - result = root_elt.findall("staticmap[cid='{0}']".format(obj['name'])) + if "name" in obj and "macaddr" in obj: + result = root_elt.findall( + "staticmap[cid='{0}'][mac='{1}']".format( + obj["name"], obj["macaddr"] + ) + ) + elif "name" in obj: + result = root_elt.findall("staticmap[cid='{0}']".format(obj["name"])) else: - result = root_elt.findall("staticmap[mac='{0}']".format(obj['macaddr'])) + result = root_elt.findall("staticmap[mac='{0}']".format(obj["macaddr"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple static maps for cid {0}.'.format(obj['name'])) + self.fail("Found multiple static maps for cid {0}.".format(obj["name"])) else: return None @@ -74,67 +104,106 @@ def get_target_elt(self, obj, absent=False, module_result=None): # tests # def test_dhcp_static_create(self): - """ test """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.0.0.101', netif='opt1') - command = ( - "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.0.0.101'" + """test""" + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ac", + ipaddr="10.0.0.101", + netif="opt1", ) + command = "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.0.0.101'" self.do_module_test(obj, command=command) def test_dhcp_static_create_empty(self): - """ test """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.10.0.101', netif='opt2') - command = ( - "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.10.0.101'" + """test""" + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ac", + ipaddr="10.10.0.101", + netif="opt2", ) + command = "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.10.0.101'" self.do_module_test(obj, command=command) def test_dhcp_static_create_display(self): - """ test create with netif display name """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.0.0.101', netif='pub') - command = ( - "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.0.0.101'" + """test create with netif display name""" + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ac", + ipaddr="10.0.0.101", + netif="pub", ) + command = "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ac', ipaddr='10.0.0.101'" self.do_module_test(obj, command=command) def test_dhcp_static_create_arp_table_static_entry(self): - """ test create with arp_table_static_entry """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ab', ipaddr='10.0.0.101', netif='opt1', arp_table_static_entry=True) - command = ( - "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ab', ipaddr='10.0.0.101', arp_table_static_entry=True" + """test create with arp_table_static_entry""" + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ab", + ipaddr="10.0.0.101", + netif="opt1", + arp_table_static_entry=True, ) + command = "create dhcp_static 'test_entry', macaddr='ab:ab:ab:ab:ab:ab', ipaddr='10.0.0.101', arp_table_static_entry=True" self.do_module_test(obj, command=command) def test_dhcp_static_create_wrong_subnet(self): - """ test create with IP address in the wrong subnet """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ab', ipaddr='1.2.3.4', netif='opt1') - self.do_module_test(obj, failed=True, msg='The IP address must lie in the opt1 subnet.') + """test create with IP address in the wrong subnet""" + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ab", + ipaddr="1.2.3.4", + netif="opt1", + ) + self.do_module_test( + obj, failed=True, msg="The IP address must lie in the opt1 subnet." + ) def test_dhcp_static_create_no_netif(self): - """ test create with no netif """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ab', ipaddr='1.2.3.4') - self.do_module_test(obj, failed=True, msg='Multiple DHCP servers enabled and no netif specified') + """test create with no netif""" + obj = dict(name="test_entry", macaddr="ab:ab:ab:ab:ab:ab", ipaddr="1.2.3.4") + self.do_module_test( + obj, failed=True, msg="Multiple DHCP servers enabled and no netif specified" + ) def test_dhcp_static_create_ifgroup(self): - """ test create with interface group """ - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ab', ipaddr='1.2.3.4', netif='IFGROUP1') - self.do_module_test(obj, failed=True, msg='DHCP cannot be configured for interface groups') + """test create with interface group""" + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ab", + ipaddr="1.2.3.4", + netif="IFGROUP1", + ) + self.do_module_test( + obj, failed=True, msg="DHCP cannot be configured for interface groups" + ) def test_dhcp_static_create_invalid_macaddr(self): - """ test create with invalid macaddr """ - msg = 'A valid MAC address must be specified.' - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:ab:ab', ipaddr='10.10.0.101', netif='opt2') + """test create with invalid macaddr""" + msg = "A valid MAC address must be specified." + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:ab:ab", + ipaddr="10.10.0.101", + netif="opt2", + ) self.do_module_test(obj, failed=True, msg=msg) - obj = dict(name='test_entry', macaddr='ab:ab:ab:ab:ab:hh', ipaddr='10.10.0.101', netif='opt2') + obj = dict( + name="test_entry", + macaddr="ab:ab:ab:ab:ab:hh", + ipaddr="10.10.0.101", + netif="opt2", + ) self.do_module_test(obj, failed=True, msg=msg) def test_dhcp_static_delete_macaddr(self): - """ test """ - obj = dict(macaddr='ab:ab:ab:ab:ab:ab', netif='opt1', state='absent') + """test""" + obj = dict(macaddr="ab:ab:ab:ab:ab:ab", netif="opt1", state="absent") command = "delete dhcp_static ''" def test_dhcp_static_delete_name(self): - """ test """ - obj = dict(name='dhcphostid', netif='opt1', state='absent') + """test""" + obj = dict(name="dhcphostid", netif="opt1", state="absent") command = "delete dhcp_static 'dhcphostid'" self.do_module_test(obj, command=command, delete=True) diff --git a/tests/unit/plugins/modules/test_pfsense_dns_resolver.py b/tests/unit/plugins/modules/test_pfsense_dns_resolver.py index a9a7e53e..696f7c2f 100644 --- a/tests/unit/plugins/modules/test_pfsense_dns_resolver.py +++ b/tests/unit/plugins/modules/test_pfsense_dns_resolver.py @@ -2,110 +2,133 @@ # Copyright: (c) 2025, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type from ansible_collections.pfsensible.core.plugins.modules import pfsense_dns_resolver -from ansible_collections.pfsensible.core.plugins.modules.pfsense_dns_resolver import PFSenseDNSResolverModule +from ansible_collections.pfsensible.core.plugins.modules.pfsense_dns_resolver import ( + PFSenseDNSResolverModule, +) from .pfsense_module import TestPFSenseModule -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) class TestPFSenseDNSResolverModule(TestPFSenseModule): - module = pfsense_dns_resolver def __init__(self, *args, **kwargs): super(TestPFSenseDNSResolverModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_dns_resolver_config_full.xml' + self.config_file = "pfsense_dns_resolver_config_full.xml" self.pfmodule = PFSenseDNSResolverModule def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseDNSResolverModule, self).setUp() - self.mock_php = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php') + self.mock_php = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php" + ) self.php = self.mock_php.start() - self.php.return_value = {'wan': 'WAN', 'lan': 'LAN', '_llocwan': 'WAN IPv6 Link-Local', '_lloclan': 'LAN IPv6 Link-Local', 'lo0': 'Localhost'} + self.php.return_value = { + "wan": "WAN", + "lan": "LAN", + "_llocwan": "WAN IPv6 Link-Local", + "_lloclan": "LAN IPv6 Link-Local", + "lo0": "Localhost", + } def check_target_elt(self, obj, target_elt, target_idx=-1): - """ test the xml definition """ - self.check_param_equal(obj, target_elt, 'port') - self.check_param_bool(obj, target_elt, 'enablessl') - self.check_param_equal(obj, target_elt, 'sslcert') - self.check_param_equal(obj, target_elt, 'tlsport') + """test the xml definition""" + self.check_param_equal(obj, target_elt, "port") + self.check_param_bool(obj, target_elt, "enablessl") + self.check_param_equal(obj, target_elt, "sslcert") + self.check_param_equal(obj, target_elt, "tlsport") # TODO - figure out how these parameters work # self.check_param_equal(obj, target_elt, 'active_interface') # self.check_param_equal(obj, target_elt, 'outgoing_interface') # self.check_param_equal(obj, target_elt, 'system_domain_local_zone_type') - self.check_param_bool(obj, target_elt, 'dnssec', default=True) - self.check_param_bool(obj, target_elt, 'forwarding') - self.check_param_bool(obj, target_elt, 'forward_tls_upstream') - self.check_param_bool(obj, target_elt, 'regdhcp') - self.check_param_bool(obj, target_elt, 'regdhcpstatic') - self.check_param_bool(obj, target_elt, 'regovpnclients') - self.check_param_equal(obj, target_elt, 'custom_options') - self.check_param_equal(obj, target_elt, 'hosts') - self.check_param_equal(obj, target_elt, 'domainoverrides') - self.check_param_bool(obj, target_elt, 'hideidentity', default=True) - self.check_param_bool(obj, target_elt, 'hideversions', default=True) - self.check_param_bool(obj, target_elt, 'prefetch') - self.check_param_bool(obj, target_elt, 'prefetchkey') - self.check_param_bool(obj, target_elt, 'dnssecstripped', default=True) - self.check_param_equal(obj, target_elt, 'msgcachesize', default=4) - self.check_param_equal(obj, target_elt, 'outgoing_num_tcp', default=10) - self.check_param_equal(obj, target_elt, 'incoming_num_tcp', default=10) - self.check_param_equal(obj, target_elt, 'edns_buffer_size', default="auto") - self.check_param_equal(obj, target_elt, 'num_queries_per_thread', default=512) - self.check_param_equal(obj, target_elt, 'jostle_timeout', default=200) - self.check_param_equal(obj, target_elt, 'cache_max_ttl', default=86400) - self.check_param_equal(obj, target_elt, 'cache_min_ttl', default=0) - self.check_param_equal(obj, target_elt, 'infra_host_ttl', default=900) - self.check_param_equal(obj, target_elt, 'infra_cache_numhosts', default=10000) - self.check_param_equal(obj, target_elt, 'unwanted_reply_threshold', default="disabled") - self.check_param_equal(obj, target_elt, 'log_verbosity', default=1) + self.check_param_bool(obj, target_elt, "dnssec", default=True) + self.check_param_bool(obj, target_elt, "forwarding") + self.check_param_bool(obj, target_elt, "forward_tls_upstream") + self.check_param_bool(obj, target_elt, "regdhcp") + self.check_param_bool(obj, target_elt, "regdhcpstatic") + self.check_param_bool(obj, target_elt, "regovpnclients") + self.check_param_equal(obj, target_elt, "custom_options") + self.check_param_equal(obj, target_elt, "hosts") + self.check_param_equal(obj, target_elt, "domainoverrides") + self.check_param_bool(obj, target_elt, "hideidentity", default=True) + self.check_param_bool(obj, target_elt, "hideversions", default=True) + self.check_param_bool(obj, target_elt, "prefetch") + self.check_param_bool(obj, target_elt, "prefetchkey") + self.check_param_bool(obj, target_elt, "dnssecstripped", default=True) + self.check_param_equal(obj, target_elt, "msgcachesize", default=4) + self.check_param_equal(obj, target_elt, "outgoing_num_tcp", default=10) + self.check_param_equal(obj, target_elt, "incoming_num_tcp", default=10) + self.check_param_equal(obj, target_elt, "edns_buffer_size", default="auto") + self.check_param_equal(obj, target_elt, "num_queries_per_thread", default=512) + self.check_param_equal(obj, target_elt, "jostle_timeout", default=200) + self.check_param_equal(obj, target_elt, "cache_max_ttl", default=86400) + self.check_param_equal(obj, target_elt, "cache_min_ttl", default=0) + self.check_param_equal(obj, target_elt, "infra_host_ttl", default=900) + self.check_param_equal(obj, target_elt, "infra_cache_numhosts", default=10000) + self.check_param_equal( + obj, target_elt, "unwanted_reply_threshold", default="disabled" + ) + self.check_param_equal(obj, target_elt, "log_verbosity", default=1) def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - return self.assert_find_xml_elt(self.xml_result, 'unbound') + """get the generated xml definition""" + return self.assert_find_xml_elt(self.xml_result, "unbound") ############## # tests # def test_dns_resolver_init(self): - """ test init of the DNS Resolver """ + """test init of the DNS Resolver""" obj = dict() - command_as_list = ["update dns_resolver pfsense_dns_resolver set active_interface='all', " - "outgoing_interface='all', system_domain_local_zone_type='transparent', " - "msgcachesize='4', outgoing_num_tcp='10', incoming_num_tcp='10', " - "edns_buffer_size='auto', num_queries_per_thread='512', jostle_timeout='200', " - "cache_max_ttl='86400', cache_min_ttl='0', infra_host_ttl='900', " - "infra_cache_numhosts='10000', unwanted_reply_threshold='disabled', " - "log_verbosity='1'"] + command_as_list = [ + "update dns_resolver pfsense_dns_resolver set active_interface='all', " + "outgoing_interface='all', system_domain_local_zone_type='transparent', " + "msgcachesize='4', outgoing_num_tcp='10', incoming_num_tcp='10', " + "edns_buffer_size='auto', num_queries_per_thread='512', jostle_timeout='200', " + "cache_max_ttl='86400', cache_min_ttl='0', infra_host_ttl='900', " + "infra_cache_numhosts='10000', unwanted_reply_threshold='disabled', " + "log_verbosity='1'" + ] command = "".join(command_as_list) - self.config_file = 'pfsense_dns_resolver_config_init.xml' + self.config_file = "pfsense_dns_resolver_config_init.xml" self.do_module_test(obj, command=command) def test_dns_resolver_change(self): - """ test initialization of the DNS Resolver """ - obj = dict( - active_interface=['lan', 'lo0'], - outgoing_interface=['wan'] - ) - command_as_list = ["update dns_resolver pfsense_dns_resolver set active_interface='lan,lo0', outgoing_interface='wan'"] + """test initialization of the DNS Resolver""" + obj = dict(active_interface=["lan", "lo0"], outgoing_interface=["wan"]) + command_as_list = [ + "update dns_resolver pfsense_dns_resolver set active_interface='lan,lo0', outgoing_interface='wan'" + ] command = "".join(command_as_list) self.do_module_test(obj, command=command) def test_dns_resolver_noop(self): - """ test noop of the DNS Resolver """ + """test noop of the DNS Resolver""" obj = dict() self.do_module_test(obj, changed=False) def test_dns_resolver_domainoverrides_forward_tls_upstream(self): - """ test initialization of the DNS Resolver """ + """test initialization of the DNS Resolver""" obj = dict( - domainoverrides=[dict(domain="test.example.com", descr="A description", forward_tls_upstream=False, ip="10.0.0.3", tls_hostname='')] + domainoverrides=[ + dict( + domain="test.example.com", + descr="A description", + forward_tls_upstream=False, + ip="10.0.0.3", + tls_hostname="", + ) + ] ) command_as_list = ["update dns_resolver pfsense_dns_resolver set "] command = "".join(command_as_list) @@ -140,4 +163,6 @@ def test_dns_resolver_domainoverrides_forward_tls_upstream(self): """ # noqa: E101,W191 - self.do_module_test(obj, command=command, expected_elt_string=expected_elt_string) + self.do_module_test( + obj, command=command, expected_elt_string=expected_elt_string + ) diff --git a/tests/unit/plugins/modules/test_pfsense_gateway.py b/tests/unit/plugins/modules/test_pfsense_gateway.py index 916077a4..59a5ae06 100644 --- a/tests/unit/plugins/modules/test_pfsense_gateway.py +++ b/tests/unit/plugins/modules/test_pfsense_gateway.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,43 +12,52 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_gateway -from ansible_collections.pfsensible.core.plugins.module_utils.gateway import PFSenseGatewayModule +from ansible_collections.pfsensible.core.plugins.module_utils.gateway import ( + PFSenseGatewayModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseGatewayModule(TestPFSenseModule): - module = pfsense_gateway def __init__(self, *args, **kwargs): super(TestPFSenseGatewayModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_gateway_config.xml' + self.config_file = "pfsense_gateway_config.xml" self.pfmodule = PFSenseGatewayModule def check_target_elt(self, obj, target_elt): - """ test the xml definition """ - - self.check_param_equal_or_not_find(obj, target_elt, 'monitor') - - self.check_param_equal_or_not_find(obj, target_elt, 'disabled', empty=True) - self.check_param_equal_or_not_find(obj, target_elt, 'monitor_disable', empty=True) - self.check_param_equal_or_not_find(obj, target_elt, 'action_disable', empty=True) - self.check_param_equal_or_not_find(obj, target_elt, 'force_down', empty=True) - self.check_param_equal_or_not_find(obj, target_elt, 'nonlocalgateway', empty=True) - - self.check_value_equal(target_elt, 'interface', self.unalias_interface(obj['interface'])) - self.check_param_equal(obj, target_elt, 'descr') - self.check_param_equal(obj, target_elt, 'weight', '1') - self.check_param_equal(obj, target_elt, 'gateway') - self.check_param_equal(obj, target_elt, 'ipprotocol', 'inet') + """test the xml definition""" + + self.check_param_equal_or_not_find(obj, target_elt, "monitor") + + self.check_param_equal_or_not_find(obj, target_elt, "disabled", empty=True) + self.check_param_equal_or_not_find( + obj, target_elt, "monitor_disable", empty=True + ) + self.check_param_equal_or_not_find( + obj, target_elt, "action_disable", empty=True + ) + self.check_param_equal_or_not_find(obj, target_elt, "force_down", empty=True) + self.check_param_equal_or_not_find( + obj, target_elt, "nonlocalgateway", empty=True + ) + + self.check_value_equal( + target_elt, "interface", self.unalias_interface(obj["interface"]) + ) + self.check_param_equal(obj, target_elt, "descr") + self.check_param_equal(obj, target_elt, "weight", "1") + self.check_param_equal(obj, target_elt, "gateway") + self.check_param_equal(obj, target_elt, "ipprotocol", "inet") def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - rules_elt = self.assert_find_xml_elt(self.xml_result, 'gateways') + """get the generated xml definition""" + rules_elt = self.assert_find_xml_elt(self.xml_result, "gateways") for item in rules_elt: - name_elt = item.find('name') - if name_elt is not None and name_elt.text == obj['name']: + name_elt = item.find("name") + if name_elt is not None and name_elt.text == obj["name"]: return item return None @@ -56,169 +66,212 @@ def get_target_elt(self, obj, absent=False, module_result=None): # tests # def test_gateway_create(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='192.168.1.1') + """test""" + obj = dict(name="test_gw", interface="lan", gateway="192.168.1.1") command = "create gateway 'test_gw', interface='lan', gateway='192.168.1.1'" self.do_module_test(obj, command=command) def test_gateway_create_with_params(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='192.168.1.1', descr='a test gw', monitor='8.8.8.8', weight=10) + """test""" + obj = dict( + name="test_gw", + interface="lan", + gateway="192.168.1.1", + descr="a test gw", + monitor="8.8.8.8", + weight=10, + ) command = "create gateway 'test_gw', interface='lan', gateway='192.168.1.1', descr='a test gw', monitor='8.8.8.8', weight='10'" self.do_module_test(obj, command=command) def test_gateway_create_ipv6(self): - """ test """ - obj = dict(name='test_gw', interface='wan', ipprotocol='inet6', gateway='2001::1') + """test""" + obj = dict( + name="test_gw", interface="wan", ipprotocol="inet6", gateway="2001::1" + ) command = "create gateway 'test_gw', interface='wan', ipprotocol='inet6', gateway='2001::1'" self.do_module_test(obj, command=command) def test_gateway_create_in_vip(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='10.255.2.1') + """test""" + obj = dict(name="test_gw", interface="lan", gateway="10.255.2.1") command = "create gateway 'test_gw', interface='lan', gateway='10.255.2.1'" self.do_module_test(obj, command=command) def test_gateway_create_invalid_name(self): - """ test """ - obj = dict(name='___', interface='lan', gateway='192.168.1.1') + """test""" + obj = dict(name="___", interface="lan", gateway="192.168.1.1") msg = "The gateway name '___' must be less than 32 characters long, may not consist of only numbers, " msg += "may not consist of only underscores, and may only contain the following characters: a-z, A-Z, 0-9, _" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_interface(self): - """ test """ - obj = dict(name='test_gw', interface='lan_232', gateway='192.168.1.1') - msg = 'lan_232 is not a valid interface' + """test""" + obj = dict(name="test_gw", interface="lan_232", gateway="192.168.1.1") + msg = "lan_232 is not a valid interface" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_nonlocal(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='1.2.3.4', nonlocalgateway=True) + """test""" + obj = dict( + name="test_gw", interface="lan", gateway="1.2.3.4", nonlocalgateway=True + ) command = "create gateway 'test_gw', interface='lan', gateway='1.2.3.4', nonlocalgateway=True" self.do_module_test(obj, command=command) def test_gateway_create_invalid_ip(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='acme.dyndns.org') - msg = 'gateway must use an IPv4 address' + """test""" + obj = dict(name="test_gw", interface="lan", gateway="acme.dyndns.org") + msg = "gateway must use an IPv4 address" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_ip2(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='1.2.3.4') + """test""" + obj = dict(name="test_gw", interface="lan", gateway="1.2.3.4") msg = "The gateway address 1.2.3.4 does not lie within one of the chosen interface's subnets." self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_ip3(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='2001::1') - msg = 'gateway must use an IPv4 address' + """test""" + obj = dict(name="test_gw", interface="lan", gateway="2001::1") + msg = "gateway must use an IPv4 address" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_ip4(self): - """ test """ - obj = dict(name='test_gw', interface='vt1', gateway='192.168.1.1') - msg = 'Cannot add IPv4 Gateway Address because no IPv4 address could be found on the interface.' + """test""" + obj = dict(name="test_gw", interface="vt1", gateway="192.168.1.1") + msg = "Cannot add IPv4 Gateway Address because no IPv4 address could be found on the interface." self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_monitor(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='192.168.1.1', monitor='2001::1') - msg = 'monitor must use an IPv4 address' + """test""" + obj = dict( + name="test_gw", interface="lan", gateway="192.168.1.1", monitor="2001::1" + ) + msg = "monitor must use an IPv4 address" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_ipv6(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='2001::1', ipprotocol='inet6') + """test""" + obj = dict( + name="test_gw", interface="lan", gateway="2001::1", ipprotocol="inet6" + ) msg = "Cannot add IPv6 Gateway Address because no IPv6 address could be found on the interface." self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_ipv6_2(self): - """ test """ - obj = dict(name='test_gw', interface='wan', gateway='192.168.1.2', ipprotocol='inet6') + """test""" + obj = dict( + name="test_gw", interface="wan", gateway="192.168.1.2", ipprotocol="inet6" + ) msg = "gateway must use an IPv6 address" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_ipv6_monitor(self): - """ test """ - obj = dict(name='test_gw', interface='wan', ipprotocol='inet6', gateway='2001::1', monitor='192.168.1.1') - msg = 'monitor must use an IPv6 address' + """test""" + obj = dict( + name="test_gw", + interface="wan", + ipprotocol="inet6", + gateway="2001::1", + monitor="192.168.1.1", + ) + msg = "monitor must use an IPv6 address" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_create_invalid_weight(self): - """ test """ - obj = dict(name='test_gw', interface='lan', gateway='192.168.1.1', weight='40') - msg = 'weight must be between 1 and 30' + """test""" + obj = dict(name="test_gw", interface="lan", gateway="192.168.1.1", weight="40") + msg = "weight must be between 1 and 30" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_update_noop(self): - """ test """ - obj = dict(name='GW_WAN', interface='wan', gateway='192.168.240.1', descr='Interface wan Gateway') + """test""" + obj = dict( + name="GW_WAN", + interface="wan", + gateway="192.168.240.1", + descr="Interface wan Gateway", + ) self.do_module_test(obj, changed=False) def test_gateway_update_dynamic(self): - """ test """ - obj = dict(name='OPT3_VTIV4', interface='lan', gateway='dynamic') + """test""" + obj = dict(name="OPT3_VTIV4", interface="lan", gateway="dynamic") msg = "The gateway use 'dynamic' as a target. You can not change the interface" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_update_dynamic2(self): - """ test """ - obj = dict(name='OPT3_VTIV4', interface='lan_100', gateway='1.2.3.4') + """test""" + obj = dict(name="OPT3_VTIV4", interface="lan_100", gateway="1.2.3.4") msg = "The gateway use 'dynamic' as a target. This is read-only, so you must set gateway as dynamic too" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_update_dynamic3(self): - """ test """ - obj = dict(name='OPT3_VTIV4', interface='lan_100', gateway='dynamic', ipprotocol='inet6') + """test""" + obj = dict( + name="OPT3_VTIV4", + interface="lan_100", + gateway="dynamic", + ipprotocol="inet6", + ) msg = "The gateway use 'dynamic' as a target. You can not change ipprotocol" self.do_module_test(obj, msg=msg, failed=True) def test_gateway_update_dynamic4(self): - """ test """ - obj = dict(name='OPT3_VTIV4', interface='lan_100', gateway='dynamic', weight=2) + """test""" + obj = dict(name="OPT3_VTIV4", interface="lan_100", gateway="dynamic", weight=2) command = "update gateway 'OPT3_VTIV4' set weight='2'" self.do_module_test(obj, command=command) def test_gateway_update_interface(self): - """ test """ - obj = dict(name='GW_WAN', interface='lan', gateway='192.168.1.1', descr='Interface wan Gateway') + """test""" + obj = dict( + name="GW_WAN", + interface="lan", + gateway="192.168.1.1", + descr="Interface wan Gateway", + ) command = "update gateway 'GW_WAN' set interface='lan', gateway='192.168.1.1'" self.do_module_test(obj, command=command) def test_gateway_update_bools_and_monitor(self): - """ test """ - obj = dict(name='GW_LAN', interface='lan', gateway='192.168.1.1', descr='Interface lan Gateway') + """test""" + obj = dict( + name="GW_LAN", + interface="lan", + gateway="192.168.1.1", + descr="Interface lan Gateway", + ) command = "update gateway 'GW_LAN' set disabled=False, monitor=none, monitor_disable=False, action_disable=False, force_down=False" self.do_module_test(obj, command=command) def test_gateway_delete(self): - """ test """ - obj = dict(name='GW_WAN2') + """test""" + obj = dict(name="GW_WAN2") command = "delete gateway 'GW_WAN2'" self.do_module_test(obj, command=command, delete=True) def test_gateway_delete_static(self): - """ test """ - obj = dict(name='OPT3_VTIV4') + """test""" + obj = dict(name="OPT3_VTIV4") msg = "The gateway use 'dynamic' as a target. You can not delete it" self.do_module_test(obj, msg=msg, delete=True, failed=True) def test_gateway_delete_default(self): - """ test """ - obj = dict(name='GW_DEFAULT') + """test""" + obj = dict(name="GW_DEFAULT") msg = "The gateway is still in use. You can not delete it" self.do_module_test(obj, msg=msg, delete=True, failed=True) def test_gateway_delete_in_group(self): - """ test """ - obj = dict(name='GW_LAN') + """test""" + obj = dict(name="GW_LAN") msg = "The gateway is still in use. You can not delete it" self.do_module_test(obj, msg=msg, delete=True, failed=True) def test_gateway_delete_in_route(self): - """ test """ - obj = dict(name='GW_WAN') + """test""" + obj = dict(name="GW_WAN") msg = "The gateway is still in use. You can not delete it" self.do_module_test(obj, msg=msg, delete=True, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_haproxy_backend.py b/tests/unit/plugins/modules/test_pfsense_haproxy_backend.py index f55275b5..7d845482 100644 --- a/tests/unit/plugins/modules/test_pfsense_haproxy_backend.py +++ b/tests/unit/plugins/modules/test_pfsense_haproxy_backend.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,39 +12,41 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_haproxy_backend -from ansible_collections.pfsensible.core.plugins.module_utils.haproxy_backend import PFSenseHaproxyBackendModule +from ansible_collections.pfsensible.core.plugins.module_utils.haproxy_backend import ( + PFSenseHaproxyBackendModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseHaproxyBackendModule(TestPFSenseModule): - module = pfsense_haproxy_backend def __init__(self, *args, **kwargs): super(TestPFSenseHaproxyBackendModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_haproxy_backend_config.xml' + self.config_file = "pfsense_haproxy_backend_config.xml" self.pfmodule = PFSenseHaproxyBackendModule ############## # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated backend xml definition """ - pkgs_elt = self.assert_find_xml_elt(self.xml_result, 'installedpackages') - hap_elt = self.assert_find_xml_elt(pkgs_elt, 'haproxy') - backends_elt = self.assert_find_xml_elt(hap_elt, 'ha_pools') + """get the generated backend xml definition""" + pkgs_elt = self.assert_find_xml_elt(self.xml_result, "installedpackages") + hap_elt = self.assert_find_xml_elt(pkgs_elt, "haproxy") + backends_elt = self.assert_find_xml_elt(hap_elt, "ha_pools") for item in backends_elt: - name_elt = item.find('name') - if name_elt is not None and name_elt.text == obj['name']: + name_elt = item.find("name") + if name_elt is not None and name_elt.text == obj["name"]: return item if not absent: - self.fail('haproxy_backend ' + obj['name'] + ' not found.') + self.fail("haproxy_backend " + obj["name"] + " not found.") return None def check_target_elt(self, obj, target_elt, backend_id=100): - """ test the xml definition of backend """ + """test the xml definition of backend""" + def _check_elt(name, fname=None, default=None): if fname is None: fname = name @@ -60,70 +63,84 @@ def _check_bool_elt(name, fname=None): fname = name if obj.get(name): - self.assert_xml_elt_equal(target_elt, fname, 'yes') + self.assert_xml_elt_equal(target_elt, fname, "yes") else: self.assert_xml_elt_is_none_or_empty(target_elt, fname) - self.assert_xml_elt_equal(target_elt, 'id', str(backend_id)) + self.assert_xml_elt_equal(target_elt, "id", str(backend_id)) # checking balance - if 'balance' in obj and obj['balance'] != 'none': - self.assert_xml_elt_equal(target_elt, 'balance', obj['balance']) + if "balance" in obj and obj["balance"] != "none": + self.assert_xml_elt_equal(target_elt, "balance", obj["balance"]) else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'balance') + self.assert_xml_elt_is_none_or_empty(target_elt, "balance") # check everything else - _check_elt('balance_urilen') - _check_elt('balance_uridepth') - _check_bool_elt('balance_uriwhole') - _check_elt('connection_timeout') - _check_elt('server_timeout') - _check_elt('check_type', default='none') - _check_elt('check_frequency', 'checkinter') - _check_elt('retries') - _check_bool_elt('log_checks', 'log-health-checks') - _check_elt('httpcheck_method') - _check_elt('monitor_uri') - _check_elt('monitor_httpversion') - _check_elt('monitor_username') - _check_elt('monitor_domain') + _check_elt("balance_urilen") + _check_elt("balance_uridepth") + _check_bool_elt("balance_uriwhole") + _check_elt("connection_timeout") + _check_elt("server_timeout") + _check_elt("check_type", default="none") + _check_elt("check_frequency", "checkinter") + _check_elt("retries") + _check_bool_elt("log_checks", "log-health-checks") + _check_elt("httpcheck_method") + _check_elt("monitor_uri") + _check_elt("monitor_httpversion") + _check_elt("monitor_username") + _check_elt("monitor_domain") ############## # tests # def test_haproxy_backend_create(self): - """ test creation of a new backend """ - backend = dict(name='exchange') + """test creation of a new backend""" + backend = dict(name="exchange") command = "create haproxy_backend 'exchange', balance='none', check_type='none'" self.do_module_test(backend, command=command, backend_id=102) def test_haproxy_backend_create2(self): - """ test creation of a new backend with some parameters""" - backend = dict(name='exchange', balance='roundrobin', check_type='HTTP') - command = "create haproxy_backend 'exchange', balance='roundrobin', check_type='HTTP'" + """test creation of a new backend with some parameters""" + backend = dict(name="exchange", balance="roundrobin", check_type="HTTP") + command = ( + "create haproxy_backend 'exchange', balance='roundrobin', check_type='HTTP'" + ) self.do_module_test(backend, command=command, backend_id=102) def test_haproxy_backend_create_invalid_name(self): - """ test creation of a new backend """ - backend = dict(name='exchange test') + """test creation of a new backend""" + backend = dict(name="exchange test") msg = "The field 'name' contains invalid characters." self.do_module_test(backend, msg=msg, failed=True) def test_haproxy_backend_delete(self): - """ test deletion of a backend """ - backend = dict(name='test-backend') + """test deletion of a backend""" + backend = dict(name="test-backend") command = "delete haproxy_backend 'test-backend'" self.do_module_test(backend, delete=True, command=command) def test_haproxy_backend_update_noop(self): - """ test not updating a backend """ + """test not updating a backend""" backend = dict( - name='test-backend', balance='uri', balance_uriwhole=True, log_checks=True, check_type='SSL', check_frequency=123456, httpcheck_method='OPTIONS' + name="test-backend", + balance="uri", + balance_uriwhole=True, + log_checks=True, + check_type="SSL", + check_frequency=123456, + httpcheck_method="OPTIONS", ) self.do_module_test(backend, changed=False) def test_haproxy_backend_update_bools(self): - """ test updating bools """ - backend = dict(name='test-backend', balance='uri', check_type='SSL', check_frequency=123456, httpcheck_method='OPTIONS') + """test updating bools""" + backend = dict( + name="test-backend", + balance="uri", + check_type="SSL", + check_frequency=123456, + httpcheck_method="OPTIONS", + ) command = "update haproxy_backend 'test-backend' set balance_uriwhole=False, log_checks=False" self.do_module_test(backend, changed=True, command=command) diff --git a/tests/unit/plugins/modules/test_pfsense_haproxy_backend_server.py b/tests/unit/plugins/modules/test_pfsense_haproxy_backend_server.py index 50ac4d0b..5d5a7256 100644 --- a/tests/unit/plugins/modules/test_pfsense_haproxy_backend_server.py +++ b/tests/unit/plugins/modules/test_pfsense_haproxy_backend_server.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -10,82 +11,86 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.pfsensible.core.plugins.modules import pfsense_haproxy_backend_server -from ansible_collections.pfsensible.core.plugins.module_utils.haproxy_backend_server import PFSenseHaproxyBackendServerModule +from ansible_collections.pfsensible.core.plugins.modules import ( + pfsense_haproxy_backend_server, +) +from ansible_collections.pfsensible.core.plugins.module_utils.haproxy_backend_server import ( + PFSenseHaproxyBackendServerModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseHaproxyBackendServerModule(TestPFSenseModule): - module = pfsense_haproxy_backend_server def __init__(self, *args, **kwargs): super(TestPFSenseHaproxyBackendServerModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_haproxy_backend_server_config.xml' + self.config_file = "pfsense_haproxy_backend_server_config.xml" self.pfmodule = PFSenseHaproxyBackendServerModule ############## # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated backend server xml definition """ - pkgs_elt = self.assert_find_xml_elt(self.xml_result, 'installedpackages') - hap_elt = self.assert_find_xml_elt(pkgs_elt, 'haproxy') - backends_elt = self.assert_find_xml_elt(hap_elt, 'ha_pools') + """get the generated backend server xml definition""" + pkgs_elt = self.assert_find_xml_elt(self.xml_result, "installedpackages") + hap_elt = self.assert_find_xml_elt(pkgs_elt, "haproxy") + backends_elt = self.assert_find_xml_elt(hap_elt, "ha_pools") for item in backends_elt: - name_elt = item.find('name') - if name_elt is not None and name_elt.text == obj['backend']: + name_elt = item.find("name") + if name_elt is not None and name_elt.text == obj["backend"]: backend_elt = item break if backend_elt is None: - self.fail('haproxy backend ' + obj['backend'] + ' not found.') + self.fail("haproxy backend " + obj["backend"] + " not found.") - servers_elt = self.assert_find_xml_elt(backend_elt, 'ha_servers') + servers_elt = self.assert_find_xml_elt(backend_elt, "ha_servers") for item in servers_elt: - name_elt = item.find('name') - if name_elt is not None and name_elt.text == obj['name']: + name_elt = item.find("name") + if name_elt is not None and name_elt.text == obj["name"]: return item if not absent: - self.fail('haproxy backend server ' + obj['name'] + ' not found.') + self.fail("haproxy backend server " + obj["name"] + " not found.") return None @staticmethod def caref(descr): - """ return refid for ca """ - if descr == 'test ca': - return '5d85d3071588f' - if descr == 'test ca2': - return '5df5ec5668d9f' - return '' + """return refid for ca""" + if descr == "test ca": + return "5d85d3071588f" + if descr == "test ca2": + return "5df5ec5668d9f" + return "" @staticmethod def crlref(descr): - """ return refid for crl """ - if descr == 'test crl': - return '5df5edf6cae0f' - if descr == 'test crl2': - return '5df5ee048c106' - return '' + """return refid for crl""" + if descr == "test crl": + return "5df5edf6cae0f" + if descr == "test crl2": + return "5df5ee048c106" + return "" @staticmethod def certref(descr): - """ return refid for cert """ - if descr == 'test cert': - return '5df5ec78b3048' - if descr == 'test cert2': - return '5df5ec97dfd07' - return '' + """return refid for cert""" + if descr == "test cert": + return "5df5ec78b3048" + if descr == "test cert2": + return "5df5ec97dfd07" + return "" @staticmethod def idem(descr): - """ return value passed """ + """return value passed""" return descr def check_target_elt(self, obj, target_elt, server_id): - """ test the xml definition of server """ + """test the xml definition of server""" + def _check_elt(name, fname=None, default=None, fvalue=self.idem): if fname is None: fname = name @@ -104,45 +109,57 @@ def _check_bool_elt(name, fname=None, false_exists=False): fname = name if obj.get(name): - self.assert_xml_elt_equal(target_elt, fname, 'yes') + self.assert_xml_elt_equal(target_elt, fname, "yes") elif name in obj and false_exists: self.assert_xml_elt_is_none_or_empty(target_elt, fname) else: self.assert_not_find_xml_elt(target_elt, fname) - self.assert_xml_elt_equal(target_elt, 'id', str(server_id)) - - _check_elt('mode', fname='status', default='active') - _check_elt('forwardto') - _check_elt('address') - _check_elt('port') - _check_elt('weight') - _check_elt('verifyhost') - _check_elt('ca', fname='ssl-server-ca', fvalue=self.caref) - _check_elt('crl', fname='ssl-server-crl', fvalue=self.crlref) - _check_elt('clientcert', fname='ssl-server-clientcert', fvalue=self.certref) - _check_elt('cookie') - _check_elt('maxconn') - _check_elt('advanced') - _check_elt('istemplate') - - _check_bool_elt('ssl') - _check_bool_elt('checkssl') - _check_bool_elt('sslserververify') + self.assert_xml_elt_equal(target_elt, "id", str(server_id)) + + _check_elt("mode", fname="status", default="active") + _check_elt("forwardto") + _check_elt("address") + _check_elt("port") + _check_elt("weight") + _check_elt("verifyhost") + _check_elt("ca", fname="ssl-server-ca", fvalue=self.caref) + _check_elt("crl", fname="ssl-server-crl", fvalue=self.crlref) + _check_elt("clientcert", fname="ssl-server-clientcert", fvalue=self.certref) + _check_elt("cookie") + _check_elt("maxconn") + _check_elt("advanced") + _check_elt("istemplate") + + _check_bool_elt("ssl") + _check_bool_elt("checkssl") + _check_bool_elt("sslserververify") ############## # tests # def test_haproxy_backend_server_create(self): - """ test creation of a new backend server """ - server = dict(backend='test-backend', name='exchange', address='exchange.acme.org', port=443) + """test creation of a new backend server""" + server = dict( + backend="test-backend", + name="exchange", + address="exchange.acme.org", + port=443, + ) command = "create haproxy_backend_server 'exchange' on 'test-backend', status='active', address='exchange.acme.org', port=443" self.do_module_test(server, command=command, server_id=103) def test_haproxy_backend_server_create2(self): - """ test creation of a new backend server with some parameters""" + """test creation of a new backend server with some parameters""" server = dict( - backend='test-backend', name='exchange', address='exchange.acme.org', port=443, ssl=True, ca='test ca', clientcert='test cert', crl='test crl' + backend="test-backend", + name="exchange", + address="exchange.acme.org", + port=443, + ssl=True, + ca="test ca", + clientcert="test cert", + crl="test crl", ) command = ( "create haproxy_backend_server 'exchange' on 'test-backend', status='active', address='exchange.acme.org', port=443, " @@ -151,78 +168,130 @@ def test_haproxy_backend_server_create2(self): self.do_module_test(server, command=command, server_id=103) def test_haproxy_backend_server_create_invalid_backend(self): - """ test creation of a new backend server """ - server = dict(backend='test.backend', name='exchange', address='exchange.acme.org', port=443) + """test creation of a new backend server""" + server = dict( + backend="test.backend", + name="exchange", + address="exchange.acme.org", + port=443, + ) msg = "The backend named 'test.backend' does not exist" self.do_module_test(server, msg=msg, failed=True) def test_haproxy_backend_server_create_invalid_name(self): - """ test creation of a new backend server """ - server = dict(backend='test-backend', name='test exchange', address='exchange.acme.org', port=443) + """test creation of a new backend server""" + server = dict( + backend="test-backend", + name="test exchange", + address="exchange.acme.org", + port=443, + ) msg = "The field 'name' contains invalid characters" self.do_module_test(server, msg=msg, failed=True) def test_haproxy_backend_server_delete(self): - """ test deletion of a backend server """ - server = dict(backend='test-backend', name='exchange.acme.org') + """test deletion of a backend server""" + server = dict(backend="test-backend", name="exchange.acme.org") command = "delete haproxy_backend_server 'exchange.acme.org' on 'test-backend'" self.do_module_test(server, delete=True, command=command) def test_haproxy_backend_server_update_noop(self): - """ test not updating a backend server """ - server = dict(backend='test-backend', name='exchange.acme.org', address='exchange.acme.org', port=443) + """test not updating a backend server""" + server = dict( + backend="test-backend", + name="exchange.acme.org", + address="exchange.acme.org", + port=443, + ) self.do_module_test(server, changed=False) def test_haproxy_backend_server_update_frontend(self): - """ test updating a backend server """ - server = dict(backend='test-backend', name='exchange.acme.org', forwardto='test-frontend') + """test updating a backend server""" + server = dict( + backend="test-backend", name="exchange.acme.org", forwardto="test-frontend" + ) command = "update haproxy_backend_server 'exchange.acme.org' on 'test-backend' set forwardto='test-frontend', address=none, port=none" self.do_module_test(server, changed=True, command=command, server_id=101) def test_haproxy_backend_server_update_certs(self): - """ test updating certs """ + """test updating certs""" server = dict( - backend='test-backend', name='exchange2.acme.org', address='exchange2.acme.org', port=443, ca='test ca2', clientcert='test cert2', crl='test crl2' + backend="test-backend", + name="exchange2.acme.org", + address="exchange2.acme.org", + port=443, + ca="test ca2", + clientcert="test cert2", + crl="test crl2", ) command = "update haproxy_backend_server 'exchange2.acme.org' on 'test-backend' set ca='test ca2', crl='test crl2', clientcert='test cert2'" self.do_module_test(server, changed=True, command=command, server_id=102) def test_haproxy_backend_server_update_certs2(self): - """ test updating certs """ + """test updating certs""" server = dict( - backend='test-backend', name='exchange2.acme.org', address='exchange2.acme.org', port=443 + backend="test-backend", + name="exchange2.acme.org", + address="exchange2.acme.org", + port=443, ) command = "update haproxy_backend_server 'exchange2.acme.org' on 'test-backend' set ca=none, crl=none, clientcert=none" self.do_module_test(server, changed=True, command=command, server_id=102) def test_haproxy_backend_server_update_certs3(self): - """ test updating certs """ + """test updating certs""" server = dict( - backend='test-backend', name='exchange.acme.org', address='exchange.acme.org', port=443, ca='test ca2', clientcert='test cert2', crl='test crl2' + backend="test-backend", + name="exchange.acme.org", + address="exchange.acme.org", + port=443, + ca="test ca2", + clientcert="test cert2", + crl="test crl2", ) command = "update haproxy_backend_server 'exchange.acme.org' on 'test-backend' set ca='test ca2', crl='test crl2', clientcert='test cert2'" self.do_module_test(server, changed=True, command=command, server_id=101) def test_haproxy_backend_server_invalid_ca(self): - """ test updating certs """ - server = dict(backend='test-backend', name='exchange', address='exchange.acme.org', port=443, ca='test ca3') + """test updating certs""" + server = dict( + backend="test-backend", + name="exchange", + address="exchange.acme.org", + port=443, + ca="test ca3", + ) msg = "test ca3 is not a valid certificate authority" self.do_module_test(server, msg=msg, failed=True) def test_haproxy_backend_server_invalid_crl(self): - """ test updating certs """ - server = dict(backend='test-backend', name='exchange', address='exchange.acme.org', port=443, crl='test crl3') + """test updating certs""" + server = dict( + backend="test-backend", + name="exchange", + address="exchange.acme.org", + port=443, + crl="test crl3", + ) msg = "test crl3 is not a valid certificate revocation list" self.do_module_test(server, msg=msg, failed=True) def test_haproxy_backend_server_invalid_cert(self): - """ test updating certs """ - server = dict(backend='test-backend', name='exchange', address='exchange.acme.org', port=443, clientcert='test cert3') + """test updating certs""" + server = dict( + backend="test-backend", + name="exchange", + address="exchange.acme.org", + port=443, + clientcert="test cert3", + ) msg = "test cert3 is not a valid certificate" self.do_module_test(server, msg=msg, failed=True) def test_haproxy_backend_server_invalid_frontend(self): - """ test updating certs """ - server = dict(backend='test-backend', name='exchange', forwardto='test frontend') + """test updating certs""" + server = dict( + backend="test-backend", name="exchange", forwardto="test frontend" + ) msg = "The frontend named 'test frontend' does not exist" self.do_module_test(server, msg=msg, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_interface.py b/tests/unit/plugins/modules/test_pfsense_interface.py index f2ed9264..06eabc58 100644 --- a/tests/unit/plugins/modules/test_pfsense_interface.py +++ b/tests/unit/plugins/modules/test_pfsense_interface.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,33 +12,34 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_interface -from ansible_collections.pfsensible.core.plugins.module_utils.interface import PFSenseInterfaceModule +from ansible_collections.pfsensible.core.plugins.module_utils.interface import ( + PFSenseInterfaceModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseInterfaceModule(TestPFSenseModule): - module = pfsense_interface def __init__(self, *args, **kwargs): super(TestPFSenseInterfaceModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_interface_config.xml' + self.config_file = "pfsense_interface_config.xml" self.pfmodule = PFSenseInterfaceModule def setUp(self): - """ mocking up """ + """mocking up""" def php_mock(command): - if 'get_interface_list' in command: + if "get_interface_list" in command: interfaces = dict() - interfaces['vmx0'] = dict() - interfaces['vmx1'] = dict(descr='notuniq') - interfaces['vmx2'] = dict(descr='notuniq') - interfaces['vmx3'] = dict() - interfaces['vmx0.100'] = dict(descr='uniq') - interfaces['vmx1.1100'] = dict() + interfaces["vmx0"] = dict() + interfaces["vmx1"] = dict(descr="notuniq") + interfaces["vmx2"] = dict(descr="notuniq") + interfaces["vmx3"] = dict() + interfaces["vmx0.100"] = dict(descr="uniq") + interfaces["vmx1.1100"] = dict() return interfaces - return ['autoselect'] + return ["autoselect"] super(TestPFSenseInterfaceModule, self).setUp() @@ -45,7 +47,7 @@ def php_mock(command): self.php.side_effect = php_mock def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseInterfaceModule, self).tearDown() self.php.stop() @@ -54,128 +56,152 @@ def tearDown(self): # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated interface xml definition """ + """get the generated interface xml definition""" elt_filter = {} - elt_filter['descr'] = obj['descr'] + elt_filter["descr"] = obj["descr"] - return self.assert_has_xml_tag('interfaces', elt_filter, absent=absent) + return self.assert_has_xml_tag("interfaces", elt_filter, absent=absent) def check_target_elt(self, obj, target_elt): - """ test the xml definition of interface """ - if 'interface_descr' in obj and obj['interface_descr'] == 'uniq': - obj['interface'] = 'vmx0.100' - self.assert_xml_elt_equal(target_elt, 'if', self.unalias_interface(obj['interface'], physical=True)) + """test the xml definition of interface""" + if "interface_descr" in obj and obj["interface_descr"] == "uniq": + obj["interface"] = "vmx0.100" + self.assert_xml_elt_equal( + target_elt, "if", self.unalias_interface(obj["interface"], physical=True) + ) # bools - if obj.get('enable'): - self.assert_xml_elt_is_none_or_empty(target_elt, 'enable') + if obj.get("enable"): + self.assert_xml_elt_is_none_or_empty(target_elt, "enable") else: - self.assert_not_find_xml_elt(target_elt, 'enable') + self.assert_not_find_xml_elt(target_elt, "enable") - if obj.get('blockpriv'): - self.assert_xml_elt_equal(target_elt, 'blockpriv', '') + if obj.get("blockpriv"): + self.assert_xml_elt_equal(target_elt, "blockpriv", "") else: - self.assert_not_find_xml_elt(target_elt, 'blockpriv') + self.assert_not_find_xml_elt(target_elt, "blockpriv") - if obj.get('blockbogons'): - self.assert_xml_elt_equal(target_elt, 'blockbogons', '') + if obj.get("blockbogons"): + self.assert_xml_elt_equal(target_elt, "blockbogons", "") else: - self.assert_not_find_xml_elt(target_elt, 'blockbogons') + self.assert_not_find_xml_elt(target_elt, "blockbogons") # ipv4 type related - if obj.get('ipv4_type') is None or obj.get('ipv4_type') == 'none': - self.assert_not_find_xml_elt(target_elt, 'ipaddr') - self.assert_not_find_xml_elt(target_elt, 'subnet') - self.assert_not_find_xml_elt(target_elt, 'gateway') - elif obj.get('ipv4_type') == 'static': - if obj.get('ipv4_address'): - self.assert_xml_elt_equal(target_elt, 'ipaddr', obj['ipv4_address']) - if obj.get('ipv4_prefixlen'): - self.assert_xml_elt_equal(target_elt, 'subnet', str(obj['ipv4_prefixlen'])) - if obj.get('ipv4_gateway'): - self.assert_xml_elt_equal(target_elt, 'gateway', obj['ipv4_gateway']) + if obj.get("ipv4_type") is None or obj.get("ipv4_type") == "none": + self.assert_not_find_xml_elt(target_elt, "ipaddr") + self.assert_not_find_xml_elt(target_elt, "subnet") + self.assert_not_find_xml_elt(target_elt, "gateway") + elif obj.get("ipv4_type") == "static": + if obj.get("ipv4_address"): + self.assert_xml_elt_equal(target_elt, "ipaddr", obj["ipv4_address"]) + if obj.get("ipv4_prefixlen"): + self.assert_xml_elt_equal( + target_elt, "subnet", str(obj["ipv4_prefixlen"]) + ) + if obj.get("ipv4_gateway"): + self.assert_xml_elt_equal(target_elt, "gateway", obj["ipv4_gateway"]) # ipv6 type related - if obj.get('ipv6_type') is None or obj.get('ipv6_type') in ['none']: - self.assert_not_find_xml_elt(target_elt, 'ipaddrv6') - self.assert_not_find_xml_elt(target_elt, 'subnetv6') - self.assert_not_find_xml_elt(target_elt, 'gatewayv6') - elif obj.get('ipv6_type') == 'slaac': - self.assert_xml_elt_equal(target_elt, 'ipaddrv6', 'slaac') - self.assert_not_find_xml_elt(target_elt, 'subnetv6') - self.assert_not_find_xml_elt(target_elt, 'gatewayv6') - elif obj.get('ipv6_type') == 'static': - if obj.get('ipv6_address'): - self.assert_xml_elt_equal(target_elt, 'ipaddrv6', obj['ipv6_address']) - if obj.get('ipv6_prefixlen'): - self.assert_xml_elt_equal(target_elt, 'subnetv6', str(obj['ipv6_prefixlen'])) - if obj.get('ipv6_gateway'): - self.assert_xml_elt_equal(target_elt, 'gatewayv6', obj['ipv6_gateway']) + if obj.get("ipv6_type") is None or obj.get("ipv6_type") in ["none"]: + self.assert_not_find_xml_elt(target_elt, "ipaddrv6") + self.assert_not_find_xml_elt(target_elt, "subnetv6") + self.assert_not_find_xml_elt(target_elt, "gatewayv6") + elif obj.get("ipv6_type") == "slaac": + self.assert_xml_elt_equal(target_elt, "ipaddrv6", "slaac") + self.assert_not_find_xml_elt(target_elt, "subnetv6") + self.assert_not_find_xml_elt(target_elt, "gatewayv6") + elif obj.get("ipv6_type") == "static": + if obj.get("ipv6_address"): + self.assert_xml_elt_equal(target_elt, "ipaddrv6", obj["ipv6_address"]) + if obj.get("ipv6_prefixlen"): + self.assert_xml_elt_equal( + target_elt, "subnetv6", str(obj["ipv6_prefixlen"]) + ) + if obj.get("ipv6_gateway"): + self.assert_xml_elt_equal(target_elt, "gatewayv6", obj["ipv6_gateway"]) # mac, mss, mtu - if obj.get('mac'): - self.assert_xml_elt_equal(target_elt, 'spoofmac', obj['mac']) + if obj.get("mac"): + self.assert_xml_elt_equal(target_elt, "spoofmac", obj["mac"]) else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'spoofmac') + self.assert_xml_elt_is_none_or_empty(target_elt, "spoofmac") - if obj.get('mtu'): - self.assert_xml_elt_equal(target_elt, 'mtu', str(obj['mtu'])) + if obj.get("mtu"): + self.assert_xml_elt_equal(target_elt, "mtu", str(obj["mtu"])) else: - self.assert_not_find_xml_elt(target_elt, 'mtu') + self.assert_not_find_xml_elt(target_elt, "mtu") - if obj.get('mss'): - self.assert_xml_elt_equal(target_elt, 'mss', str(obj['mss'])) + if obj.get("mss"): + self.assert_xml_elt_equal(target_elt, "mss", str(obj["mss"])) else: - self.assert_not_find_xml_elt(target_elt, 'mss') + self.assert_not_find_xml_elt(target_elt, "mss") ############## # tests # def test_interface_create_no_address(self): - """ test creation of a new interface with no address """ - interface = dict(descr='VOICE', interface='vmx0.100') + """test creation of a new interface with no address""" + interface = dict(descr="VOICE", interface="vmx0.100") command = "create interface 'VOICE', port='vmx0.100'" self.do_module_test(interface, command=command) def test_interface_create_by_descr(self): - """ test creation of a new interface with interface_descr """ - interface = dict(descr='VOICE', interface_descr='uniq') + """test creation of a new interface with interface_descr""" + interface = dict(descr="VOICE", interface_descr="uniq") command = "create interface 'VOICE', port='vmx0.100'" self.do_module_test(interface, command=command) def test_interface_create_static(self): - """ test creation of a new interface with a static ip """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv4_type='static', ipv4_address='10.20.30.40', ipv4_prefixlen=24) + """test creation of a new interface with a static ip""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv4_type="static", + ipv4_address="10.20.30.40", + ipv4_prefixlen=24, + ) command = "create interface 'VOICE', port='vmx0.100', ipv4_type='static', ipv4_address='10.20.30.40', ipv4_prefixlen='24'" self.do_module_test(interface, command=command) def test_interface_create_static_ipv6(self): - """ test creation of a new interface with a static ipv6 """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv6_type='static', ipv6_address='3001::2001:22', ipv6_prefixlen=56) + """test creation of a new interface with a static ipv6""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv6_type="static", + ipv6_address="3001::2001:22", + ipv6_prefixlen=56, + ) command = "create interface 'VOICE', port='vmx0.100', ipv6_type='static', ipv6_address='3001::2001:22', ipv6_prefixlen='56'" self.do_module_test(interface, command=command) def test_interface_create_slaac(self): - """ test creation of a new interface with slaac """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv6_type='slaac') + """test creation of a new interface with slaac""" + interface = dict(descr="VOICE", interface="vmx0.100", ipv6_type="slaac") command = "create interface 'VOICE', port='vmx0.100', ipv6_type='slaac'" self.do_module_test(interface, command=command) def test_interface_create_none_mac_mtu_mss(self): - """ test creation of a new interface """ - interface = dict(descr='VOICE', interface='vmx0.100', mac='00:11:22:33:44:55', mtu=1500, mss=1100) + """test creation of a new interface""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + mac="00:11:22:33:44:55", + mtu=1500, + mss=1100, + ) command = "create interface 'VOICE', port='vmx0.100', mac='00:11:22:33:44:55', mtu='1500', mss='1100'" self.do_module_test(interface, command=command) def test_interface_delete(self): - """ test deletion of an interface """ - interface = dict(descr='vt1') + """test deletion of an interface""" + interface = dict(descr="vt1") command = "delete interface 'vt1'" self.do_module_test(interface, delete=True, command=command) def test_interface_delete_lan(self): - """ test deletion of an interface """ - interface = dict(descr='lan') + """test deletion of an interface""" + interface = dict(descr="lan") commands = [ "delete rule_separator 'test_separator', interface='lan'", "update rule 'floating_rule_2' on 'floating(lan,wan,lan_1100)' set interface='wan,lan_1100'", @@ -183,117 +209,210 @@ def test_interface_delete_lan(self): "delete rule 'antilock_out_1' on 'lan'", "delete rule 'antilock_out_2' on 'lan'", "delete rule 'antilock_out_3' on 'lan'", - "delete interface 'lan'" + "delete interface 'lan'", ] self.do_module_test(interface, delete=True, command=commands) def test_interface_delete_fails(self): - """ test deletion of an interface that is part of a group """ - interface = dict(descr='lan_1100') + """test deletion of an interface that is part of a group""" + interface = dict(descr="lan_1100") msg = "The interface is part of the group IFGROUP1. Please remove it from the group first." self.do_module_test(interface, delete=True, failed=True, msg=msg) def test_interface_update_noop(self): - """ test not updating a interface """ - interface = dict(descr='lan_1100', interface='vmx1.1100', enable=True, ipv4_type='static', ipv4_address='172.16.151.210', ipv4_prefixlen=24) + """test not updating a interface""" + interface = dict( + descr="lan_1100", + interface="vmx1.1100", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + ) self.do_module_test(interface, changed=False) def test_interface_update_name(self): - """ test updating interface name """ - interface = dict(descr='wlan_1100', interface='vmx1.1100', enable=True, ipv4_type='static', ipv4_address='172.16.151.210', ipv4_prefixlen=24) + """test updating interface name""" + interface = dict( + descr="wlan_1100", + interface="vmx1.1100", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + ) command = "update interface 'lan_1100' set interface='wlan_1100'" self.do_module_test(interface, changed=True, command=command) def test_interface_update_enable(self): - """ test disabling interface """ - interface = dict(descr='lan_1100', interface='vmx1.1100', enable=False, ipv4_type='static', ipv4_address='172.16.151.210', ipv4_prefixlen=24) + """test disabling interface""" + interface = dict( + descr="lan_1100", + interface="vmx1.1100", + enable=False, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + ) command = "update interface 'lan_1100' set enable=False" self.do_module_test(interface, changed=True, command=command) def test_interface_update_enable2(self): - """ test enabling interface """ - interface = dict(descr='vt1', interface='vmx3', enable=True) + """test enabling interface""" + interface = dict(descr="vt1", interface="vmx3", enable=True) command = "update interface 'vt1' set enable=True" self.do_module_test(interface, changed=True, command=command) def test_interface_update_mac(self): - """ test updating mac """ - interface = dict(descr='lan_1100', interface='vmx1.1100', enable=True, ipv4_type='static', - ipv4_address='172.16.151.210', ipv4_prefixlen=24, mac='00:11:22:33:44:55', ) + """test updating mac""" + interface = dict( + descr="lan_1100", + interface="vmx1.1100", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + mac="00:11:22:33:44:55", + ) command = "update interface 'lan_1100' set mac='00:11:22:33:44:55'" self.do_module_test(interface, changed=True, command=command) def test_interface_update_blocks(self): - """ test updating block fields """ - interface = dict(descr='lan_1100', interface='vmx1.1100', enable=True, ipv4_type='static', - ipv4_address='172.16.151.210', ipv4_prefixlen=24, blockpriv=True, blockbogons=True) + """test updating block fields""" + interface = dict( + descr="lan_1100", + interface="vmx1.1100", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + blockpriv=True, + blockbogons=True, + ) command = "update interface 'lan_1100' set blockpriv=True, blockbogons=True" self.do_module_test(interface, changed=True, command=command) def test_interface_error_used(self): - """ test error already used """ - interface = dict(descr='lan_1100', interface='vmx1', enable=True, ipv4_type='static', ipv4_address='172.16.151.210', ipv4_prefixlen=24) + """test error already used""" + interface = dict( + descr="lan_1100", + interface="vmx1", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + ) msg = "Port vmx1 is already in use on interface lan" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_gw(self): - """ test error no such gateway """ - interface = dict(descr='lan_1100', interface='vmx1.1100', enable=True, ipv4_type='static', - ipv4_address='172.16.151.210', ipv4_prefixlen=24, ipv4_gateway='voice_gw') + """test error no such gateway""" + interface = dict( + descr="lan_1100", + interface="vmx1.1100", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + ipv4_gateway="voice_gw", + ) msg = "Gateway voice_gw does not exist on lan_1100" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_if(self): - """ test error no such interface """ - interface = dict(descr='wlan_1100', interface='vmx1.1200', enable=True, ipv4_type='static', - ipv4_address='172.16.151.210', ipv4_prefixlen=24, ipv4_gateway='voice_gw') + """test error no such interface""" + interface = dict( + descr="wlan_1100", + interface="vmx1.1200", + enable=True, + ipv4_type="static", + ipv4_address="172.16.151.210", + ipv4_prefixlen=24, + ipv4_gateway="voice_gw", + ) msg = "vmx1.1200 can't be assigned. Interface may only be one the following: ['vmx0', 'vmx1', 'vmx2', 'vmx3', 'vmx0.100', 'vmx1.1100']" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_eq(self): - """ test error same ipv4 address """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv4_type='static', ipv4_address='192.168.1.242', ipv4_prefixlen=32) + """test error same ipv4 address""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv4_type="static", + ipv4_address="192.168.1.242", + ipv4_prefixlen=32, + ) msg = "IP address 192.168.1.242/32 is being used by or overlaps with: lan (192.168.1.242/24)" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_overlaps1(self): - """ test error same ipv4 address """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv4_type='static', ipv4_address='192.168.1.1', ipv4_prefixlen=30) + """test error same ipv4 address""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv4_type="static", + ipv4_address="192.168.1.1", + ipv4_prefixlen=30, + ) msg = "IP address 192.168.1.1/30 is being used by or overlaps with: lan (192.168.1.242/24)" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_overlaps2(self): - """ test error same ipv4 address """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv4_type='static', ipv4_address='192.168.1.1', ipv4_prefixlen=22) + """test error same ipv4 address""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv4_type="static", + ipv4_address="192.168.1.1", + ipv4_prefixlen=22, + ) msg = "IP address 192.168.1.1/22 is being used by or overlaps with: lan (192.168.1.242/24)" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_inet6_eq(self): - """ test error same ipv6 address """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv6_type='static', ipv6_address='2001::2001:22', ipv6_prefixlen=127) + """test error same ipv6 address""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv6_type="static", + ipv6_address="2001::2001:22", + ipv6_prefixlen=127, + ) msg = "IP address 2001::2001:22/127 is being used by or overlaps with: lan (2001::2001:22/64)" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_inet6_overlaps1(self): - """ test error same ipv6 address """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv6_type='static', ipv6_address='2001::2001:1', ipv6_prefixlen=64) + """test error same ipv6 address""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv6_type="static", + ipv6_address="2001::2001:1", + ipv6_prefixlen=64, + ) msg = "IP address 2001::2001:1/64 is being used by or overlaps with: lan (2001::2001:22/64)" self.do_module_test(interface, failed=True, msg=msg) def test_interface_error_inet6_overlaps2(self): - """ test error same ipv6 address """ - interface = dict(descr='VOICE', interface='vmx0.100', ipv6_type='static', ipv6_address='2001::2001', ipv6_prefixlen=56) + """test error same ipv6 address""" + interface = dict( + descr="VOICE", + interface="vmx0.100", + ipv6_type="static", + ipv6_address="2001::2001", + ipv6_prefixlen=56, + ) msg = "IP address 2001::2001/56 is being used by or overlaps with: lan (2001::2001:22/64)" self.do_module_test(interface, failed=True, msg=msg) def test_interface_delete_sub(self): - """ test delete sub interface """ - interface = dict(descr='lan_1200', interface='vmx1.1200') + """test delete sub interface""" + interface = dict(descr="lan_1200", interface="vmx1.1200") command = "delete interface 'lan_1200'" self.do_module_test(interface, delete=True, command=command) def test_interface_error_not_uniq(self): - """ test creation of a new interface with interface_descr """ - interface = dict(descr='VOICE', interface_descr='notuniq') + """test creation of a new interface with interface_descr""" + interface = dict(descr="VOICE", interface_descr="notuniq") msg = 'Multiple interfaces found for "notuniq"' self.do_module_test(interface, failed=True, msg=msg) diff --git a/tests/unit/plugins/modules/test_pfsense_interface_group.py b/tests/unit/plugins/modules/test_pfsense_interface_group.py index 4a57c440..9a2f99be 100644 --- a/tests/unit/plugins/modules/test_pfsense_interface_group.py +++ b/tests/unit/plugins/modules/test_pfsense_interface_group.py @@ -2,7 +2,8 @@ # Copyright: (c) 2024, Orioni Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -12,33 +13,34 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_interface_group -from ansible_collections.pfsensible.core.plugins.module_utils.interface_group import PFSenseInterfaceGroupModule +from ansible_collections.pfsensible.core.plugins.module_utils.interface_group import ( + PFSenseInterfaceGroupModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseInterfaceGroupModule(TestPFSenseModule): - module = pfsense_interface_group def __init__(self, *args, **kwargs): super(TestPFSenseInterfaceGroupModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_interface_config.xml' + self.config_file = "pfsense_interface_config.xml" self.pfmodule = PFSenseInterfaceGroupModule def setUp(self): - """ mocking up """ + """mocking up""" def php_mock(command): - if 'get_interface_list' in command: + if "get_interface_list" in command: interfaces = dict() - interfaces['vmx0'] = dict() - interfaces['vmx1'] = dict(descr='notuniq') - interfaces['vmx2'] = dict(descr='notuniq') - interfaces['vmx3'] = dict() - interfaces['vmx0.100'] = dict(descr='uniq') - interfaces['vmx1.1100'] = dict() + interfaces["vmx0"] = dict() + interfaces["vmx1"] = dict(descr="notuniq") + interfaces["vmx2"] = dict(descr="notuniq") + interfaces["vmx3"] = dict() + interfaces["vmx0.100"] = dict(descr="uniq") + interfaces["vmx1.1100"] = dict() return interfaces - return ['autoselect'] + return ["autoselect"] super(TestPFSenseInterfaceGroupModule, self).setUp() @@ -46,7 +48,7 @@ def php_mock(command): self.php.side_effect = php_mock def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseInterfaceGroupModule, self).tearDown() self.php.stop() @@ -55,78 +57,86 @@ def tearDown(self): # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated interface group xml definition """ + """get the generated interface group xml definition""" elt_filter = {} - elt_filter['ifname'] = obj['name'] + elt_filter["ifname"] = obj["name"] - return self.assert_has_xml_tag('ifgroups', elt_filter, absent=absent) + return self.assert_has_xml_tag("ifgroups", elt_filter, absent=absent) def check_target_elt(self, obj, target_elt): - """ test the xml definition of interface group """ + """test the xml definition of interface group""" # descr, members - if obj.get('descr'): - self.assert_xml_elt_equal(target_elt, 'descr', obj['descr']) + if obj.get("descr"): + self.assert_xml_elt_equal(target_elt, "descr", obj["descr"]) else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'descr') + self.assert_xml_elt_is_none_or_empty(target_elt, "descr") - if obj.get('members'): - self.assert_xml_elt_equal(target_elt, 'members', ' '.join(obj['members'])) + if obj.get("members"): + self.assert_xml_elt_equal(target_elt, "members", " ".join(obj["members"])) else: - self.assert_not_find_xml_elt(target_elt, 'members') + self.assert_not_find_xml_elt(target_elt, "members") ############## # tests # def test_interface_group_create(self): - """ test creation of a new interface group """ - interface_group = dict(name='IFGROUP2', members=['wan', 'lan']) + """test creation of a new interface group""" + interface_group = dict(name="IFGROUP2", members=["wan", "lan"]) command = "create interface_group 'IFGROUP2', members='wan lan'" self.do_module_test(interface_group, command=command) def test_interface_group_create_with_descr(self): - """ test creation of a new interface group with a description """ - interface_group = dict(name='IFGROUP2', members=['wan', 'lan'], descr='Primary interfaces') + """test creation of a new interface group with a description""" + interface_group = dict( + name="IFGROUP2", members=["wan", "lan"], descr="Primary interfaces" + ) command = "create interface_group 'IFGROUP2', descr='Primary interfaces', members='wan lan'" self.do_module_test(interface_group, command=command) def test_interface_group_delete(self): - """ test deletion of an interface group """ - interface_group = dict(name='IFGROUP1', state='absent') + """test deletion of an interface group""" + interface_group = dict(name="IFGROUP1", state="absent") command = "delete interface_group 'IFGROUP1'" self.do_module_test(interface_group, delete=True, command=command) def test_interface_group_update_noop(self): - """ test not updating a interface group """ - interface_group = dict(name='IFGROUP1', members=['opt1', 'opt3']) + """test not updating a interface group""" + interface_group = dict(name="IFGROUP1", members=["opt1", "opt3"]) self.do_module_test(interface_group, changed=False) def test_interface_group_update_descr(self): - """ test updating interface group description """ - interface_group = dict(name='IFGROUP1', members=['opt1', 'opt3'], descr='Opt Interfaces') + """test updating interface group description""" + interface_group = dict( + name="IFGROUP1", members=["opt1", "opt3"], descr="Opt Interfaces" + ) command = "update interface_group 'IFGROUP1' set descr='Opt Interfaces'" self.do_module_test(interface_group, changed=True, command=command) def test_interface_group_update_members(self): - """ test updating interface group members """ - interface_group = dict(name='IFGROUP1', members=['opt1', 'opt2']) + """test updating interface group members""" + interface_group = dict(name="IFGROUP1", members=["opt1", "opt2"]) command = "update interface_group 'IFGROUP1' set members='opt1 opt2'" self.do_module_test(interface_group, changed=True, command=command) def test_interface_group_error_no_members(self): - """ test error no members specified """ - interface_group = dict(name='IFGROUP2', descr='Primary interfaces') + """test error no members specified""" + interface_group = dict(name="IFGROUP2", descr="Primary interfaces") msg = "state is present but all of the following are missing: members" self.do_module_test(interface_group, failed=True, msg=msg) def test_interface_group_error_member_does_not_exist(self): - """ test error member does not exist """ - interface_group = dict(name='IFGROUP2', members=['blah'], descr='Primary interfaces') + """test error member does not exist""" + interface_group = dict( + name="IFGROUP2", members=["blah"], descr="Primary interfaces" + ) msg = 'Unknown interface name "blah".' self.do_module_test(interface_group, failed=True, msg=msg) def test_interface_group_error_members_not_uniq(self): - """ test error member does not exist """ - interface_group = dict(name='IFGROUP2', members=['opt1', 'opt1'], descr='Primary interfaces') - msg = 'List of members is not unique.' + """test error member does not exist""" + interface_group = dict( + name="IFGROUP2", members=["opt1", "opt1"], descr="Primary interfaces" + ) + msg = "List of members is not unique." self.do_module_test(interface_group, failed=True, msg=msg) diff --git a/tests/unit/plugins/modules/test_pfsense_ipsec.py b/tests/unit/plugins/modules/test_pfsense_ipsec.py index a6de1e32..d807eba1 100644 --- a/tests/unit/plugins/modules/test_pfsense_ipsec.py +++ b/tests/unit/plugins/modules/test_pfsense_ipsec.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,12 +12,13 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_ipsec -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec import PFSenseIpsecModule +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec import ( + PFSenseIpsecModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseIpsecModule(TestPFSenseModule): - module = pfsense_ipsec def __init__(self, *args, **kwargs): @@ -24,137 +26,147 @@ def __init__(self, *args, **kwargs): self.pfmodule = PFSenseIpsecModule def get_config_file(self): - """ get config file """ + """get config file""" - return 'pfsense_ipsec_config.xml' + return "pfsense_ipsec_config.xml" ############## # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated ipsec xml definition """ + """get the generated ipsec xml definition""" elt_filter = {} - elt_filter['descr'] = obj['descr'] + elt_filter["descr"] = obj["descr"] - return self.assert_has_xml_tag('ipsec', elt_filter, absent=absent) + return self.assert_has_xml_tag("ipsec", elt_filter, absent=absent) @staticmethod def caref(descr): - """ return refid for ca """ - if descr == 'test ca': - return '5db509cfed87d' - if descr == 'test ca copy': - return '5db509cfed87e' - return '' + """return refid for ca""" + if descr == "test ca": + return "5db509cfed87d" + if descr == "test ca copy": + return "5db509cfed87e" + return "" @staticmethod def certref(descr): - """ return refid for cert """ - if descr == 'webConfigurator default (5c00e5f9029df)': - return '5c00e5f9029df' - if descr == 'webConfigurator default copy': - return '5c00e5f9029de' - return '' + """return refid for cert""" + if descr == "webConfigurator default (5c00e5f9029df)": + return "5c00e5f9029df" + if descr == "webConfigurator default copy": + return "5c00e5f9029de" + return "" def check_target_elt(self, obj, target_elt): - """ test the xml definition of ipsec elt """ + """test the xml definition of ipsec elt""" # bools - if obj.get('disabled'): - self.assert_xml_elt_is_none_or_empty(target_elt, 'disabled') + if obj.get("disabled"): + self.assert_xml_elt_is_none_or_empty(target_elt, "disabled") else: - self.assert_not_find_xml_elt(target_elt, 'disabled') + self.assert_not_find_xml_elt(target_elt, "disabled") - self.check_param_bool(obj, target_elt, 'gw_duplicates') - self.check_param_equal_or_not_find(obj, target_elt, 'nattport') - for param in ['rand_time', 'reauth_time', 'rekey_time']: + self.check_param_bool(obj, target_elt, "gw_duplicates") + self.check_param_equal_or_not_find(obj, target_elt, "nattport") + for param in ["rand_time", "reauth_time", "rekey_time"]: if obj.get(param): - self.check_param_equal(obj, target_elt, 'rekey_time') - self.check_param_equal(obj, target_elt, 'reauth_time') - self.check_param_equal(obj, target_elt, 'rand_time') + self.check_param_equal(obj, target_elt, "rekey_time") + self.check_param_equal(obj, target_elt, "reauth_time") + self.check_param_equal(obj, target_elt, "rand_time") # Added in 2.5.2 - if obj.get('startaction'): - self.assert_xml_elt_equal(target_elt, 'startaction', obj['startaction']) - if obj.get('closeaction'): - self.assert_xml_elt_equal(target_elt, 'closeaction', obj['closeaction']) + if obj.get("startaction"): + self.assert_xml_elt_equal(target_elt, "startaction", obj["startaction"]) + if obj.get("closeaction"): + self.assert_xml_elt_equal(target_elt, "closeaction", obj["closeaction"]) - if obj.get('disable_reauth'): - self.assert_xml_elt_is_none_or_empty(target_elt, 'reauth_enable') + if obj.get("disable_reauth"): + self.assert_xml_elt_is_none_or_empty(target_elt, "reauth_enable") else: - self.assert_not_find_xml_elt(target_elt, 'reauth_enable') + self.assert_not_find_xml_elt(target_elt, "reauth_enable") - if obj.get('splitconn'): - self.assert_xml_elt_is_none_or_empty(target_elt, 'splitconn') + if obj.get("splitconn"): + self.assert_xml_elt_is_none_or_empty(target_elt, "splitconn") else: - self.assert_not_find_xml_elt(target_elt, 'splitconn') + self.assert_not_find_xml_elt(target_elt, "splitconn") - if obj.get('enable_dpd') is None or obj.get('enable_dpd'): - if obj.get('dpd_delay') is not None: - self.assert_xml_elt_equal(target_elt, 'dpd_delay', obj['dpd_delay']) + if obj.get("enable_dpd") is None or obj.get("enable_dpd"): + if obj.get("dpd_delay") is not None: + self.assert_xml_elt_equal(target_elt, "dpd_delay", obj["dpd_delay"]) else: - self.assert_xml_elt_equal(target_elt, 'dpd_delay', '10') + self.assert_xml_elt_equal(target_elt, "dpd_delay", "10") - if obj.get('dpd_maxfail') is not None: - self.assert_xml_elt_equal(target_elt, 'dpd_maxfail', obj['dpd_maxfail']) + if obj.get("dpd_maxfail") is not None: + self.assert_xml_elt_equal(target_elt, "dpd_maxfail", obj["dpd_maxfail"]) else: - self.assert_xml_elt_equal(target_elt, 'dpd_maxfail', '5') + self.assert_xml_elt_equal(target_elt, "dpd_maxfail", "5") else: - self.assert_not_find_xml_elt(target_elt, 'dpd_delay') - self.assert_not_find_xml_elt(target_elt, 'dpd_maxfail') + self.assert_not_find_xml_elt(target_elt, "dpd_delay") + self.assert_not_find_xml_elt(target_elt, "dpd_maxfail") - if obj.get('mobike'): - self.assert_xml_elt_equal(target_elt, 'mobike', obj['mobike']) + if obj.get("mobike"): + self.assert_xml_elt_equal(target_elt, "mobike", obj["mobike"]) # iketype & mode - self.assert_xml_elt_equal(target_elt, 'iketype', obj['iketype']) - if obj.get('mode') is not None: - self.assert_xml_elt_equal(target_elt, 'mode', obj['mode']) + self.assert_xml_elt_equal(target_elt, "iketype", obj["iketype"]) + if obj.get("mode") is not None: + self.assert_xml_elt_equal(target_elt, "mode", obj["mode"]) - if obj.get('nat_traversal') is not None: - self.assert_xml_elt_equal(target_elt, 'nat_traversal', obj['nat_traversal']) + if obj.get("nat_traversal") is not None: + self.assert_xml_elt_equal(target_elt, "nat_traversal", obj["nat_traversal"]) else: - self.assert_xml_elt_equal(target_elt, 'nat_traversal', 'on') + self.assert_xml_elt_equal(target_elt, "nat_traversal", "on") # auth - self.assert_xml_elt_equal(target_elt, 'authentication_method', obj['authentication_method']) - if obj['authentication_method'] == 'rsasig': - self.assert_xml_elt_equal(target_elt, 'certref', self.certref(obj['certificate'])) - self.assert_xml_elt_equal(target_elt, 'caref', self.caref(obj['certificate_authority'])) - self.assert_xml_elt_is_none_or_empty(target_elt, 'pre-shared-key') + self.assert_xml_elt_equal( + target_elt, "authentication_method", obj["authentication_method"] + ) + if obj["authentication_method"] == "rsasig": + self.assert_xml_elt_equal( + target_elt, "certref", self.certref(obj["certificate"]) + ) + self.assert_xml_elt_equal( + target_elt, "caref", self.caref(obj["certificate_authority"]) + ) + self.assert_xml_elt_is_none_or_empty(target_elt, "pre-shared-key") else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'certref') - self.assert_xml_elt_is_none_or_empty(target_elt, 'caref') - self.assert_xml_elt_equal(target_elt, 'pre-shared-key', obj['preshared_key']) + self.assert_xml_elt_is_none_or_empty(target_elt, "certref") + self.assert_xml_elt_is_none_or_empty(target_elt, "caref") + self.assert_xml_elt_equal( + target_elt, "pre-shared-key", obj["preshared_key"] + ) # ids - if obj.get('myid_type') is not None: - self.assert_xml_elt_equal(target_elt, 'myid_type', obj['myid_type']) + if obj.get("myid_type") is not None: + self.assert_xml_elt_equal(target_elt, "myid_type", obj["myid_type"]) else: - self.assert_xml_elt_equal(target_elt, 'myid_type', 'myaddress') - if obj.get('myid_data') is not None: - self.assert_xml_elt_equal(target_elt, 'myid_data', obj['myid_data']) + self.assert_xml_elt_equal(target_elt, "myid_type", "myaddress") + if obj.get("myid_data") is not None: + self.assert_xml_elt_equal(target_elt, "myid_data", obj["myid_data"]) - if obj.get('peerid_type') is not None: - self.assert_xml_elt_equal(target_elt, 'peerid_type', obj['peerid_type']) + if obj.get("peerid_type") is not None: + self.assert_xml_elt_equal(target_elt, "peerid_type", obj["peerid_type"]) else: - self.assert_xml_elt_equal(target_elt, 'peerid_type', 'peeraddress') - if obj.get('peerid_data') is not None: - self.assert_xml_elt_equal(target_elt, 'peerid_data', obj['peerid_data']) + self.assert_xml_elt_equal(target_elt, "peerid_type", "peeraddress") + if obj.get("peerid_data") is not None: + self.assert_xml_elt_equal(target_elt, "peerid_data", obj["peerid_data"]) # misc - self.assert_xml_elt_equal(target_elt, 'interface', self.unalias_interface(obj['interface'])) + self.assert_xml_elt_equal( + target_elt, "interface", self.unalias_interface(obj["interface"]) + ) - if obj.get('protocol') is not None: - self.assert_xml_elt_equal(target_elt, 'protocol', obj['protocol']) + if obj.get("protocol") is not None: + self.assert_xml_elt_equal(target_elt, "protocol", obj["protocol"]) else: - self.assert_xml_elt_equal(target_elt, 'protocol', 'inet') - self.assert_xml_elt_equal(target_elt, 'remote-gateway', obj['remote_gateway']) + self.assert_xml_elt_equal(target_elt, "protocol", "inet") + self.assert_xml_elt_equal(target_elt, "remote-gateway", obj["remote_gateway"]) - if obj.get('lifetime') is not None: - self.assert_xml_elt_equal(target_elt, 'lifetime', obj['lifetime']) + if obj.get("lifetime") is not None: + self.assert_xml_elt_equal(target_elt, "lifetime", obj["lifetime"]) else: - self.assert_xml_elt_equal(target_elt, 'lifetime', '28800') + self.assert_xml_elt_equal(target_elt, "lifetime", "28800") def strip_commands(self, commands): commands = commands.replace("margintime='', ", "") @@ -166,113 +178,193 @@ def strip_commands(self, commands): # tests # def test_ipsec_create_ikev2(self): - """ test creation of a new ipsec tunnel with 2.5.2 params """ + """test creation of a new ipsec tunnel with 2.5.2 params""" ipsec = dict( - descr='new_tunnel', interface='lan_100', remote_gateway='1.2.3.4', nattport=4501, iketype='ikev2', - authentication_method='pre_shared_key', preshared_key='1234', gw_duplicates=True, rekey_time=2500, reauth_time=2600, rand_time=2700) + descr="new_tunnel", + interface="lan_100", + remote_gateway="1.2.3.4", + nattport=4501, + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="1234", + gw_duplicates=True, + rekey_time=2500, + reauth_time=2600, + rand_time=2700, + ) command = ( "create ipsec 'new_tunnel', iketype='ikev2', protocol='inet', interface='lan_100', remote_gateway='1.2.3.4', nattport='4501', " "authentication_method='pre_shared_key', preshared_key='1234', myid_type='myaddress', peerid_type='peeraddress', lifetime='28800', " "rekey_time='2500', reauth_time='2600', rand_time='2700', " - "mobike='off', gw_duplicates=True, startaction='', closeaction='', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'") + "mobike='off', gw_duplicates=True, startaction='', closeaction='', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'" + ) self.do_module_test(ipsec, command=command) def test_ipsec_create_ikev1(self): - """ test creation of a new ipsec tunnel """ + """test creation of a new ipsec tunnel""" ipsec = dict( - descr='new_tunnel', interface='lan_100', remote_gateway='1.2.3.4', iketype='ikev1', - authentication_method='pre_shared_key', preshared_key='1234', mode='main', startaction='none', closeaction='none') + descr="new_tunnel", + interface="lan_100", + remote_gateway="1.2.3.4", + iketype="ikev1", + authentication_method="pre_shared_key", + preshared_key="1234", + mode="main", + startaction="none", + closeaction="none", + ) command = ( "create ipsec 'new_tunnel', iketype='ikev1', mode='main', protocol='inet', interface='lan_100', remote_gateway='1.2.3.4', " "authentication_method='pre_shared_key', preshared_key='1234', myid_type='myaddress', peerid_type='peeraddress', lifetime='28800', " - "disable_rekey=False, margintime='', startaction='none', closeaction='none', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'") + "disable_rekey=False, margintime='', startaction='none', closeaction='none', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'" + ) self.do_module_test(ipsec, command=command) def test_ipsec_create_vip_descr(self): - """ test creation of a new ipsec tunnel with vip: interface name """ + """test creation of a new ipsec tunnel with vip: interface name""" ipsec = dict( - descr='new_tunnel', interface='vip:WAN CARP', remote_gateway='1.2.3.4', iketype='ikev1', - authentication_method='pre_shared_key', preshared_key='1234', mode='main', startaction='start', closeaction='start') + descr="new_tunnel", + interface="vip:WAN CARP", + remote_gateway="1.2.3.4", + iketype="ikev1", + authentication_method="pre_shared_key", + preshared_key="1234", + mode="main", + startaction="start", + closeaction="start", + ) command = ( "create ipsec 'new_tunnel', iketype='ikev1', mode='main', protocol='inet', interface='vip:WAN CARP', remote_gateway='1.2.3.4', " "authentication_method='pre_shared_key', preshared_key='1234', myid_type='myaddress', peerid_type='peeraddress', lifetime='28800', " "disable_rekey=False, margintime='', startaction='start', closeaction='start', " - "nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'") + "nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'" + ) self.do_module_test(ipsec, command=command) def test_ipsec_create_vip_subnet(self): - """ test creation of a new ipsec tunnel with vip: interface address """ + """test creation of a new ipsec tunnel with vip: interface address""" ipsec = dict( - descr='new_tunnel', interface='vip:151.25.19.11', remote_gateway='1.2.3.4', iketype='ikev1', - authentication_method='pre_shared_key', preshared_key='1234', mode='main', startaction='trap', closeaction='trap') + descr="new_tunnel", + interface="vip:151.25.19.11", + remote_gateway="1.2.3.4", + iketype="ikev1", + authentication_method="pre_shared_key", + preshared_key="1234", + mode="main", + startaction="trap", + closeaction="trap", + ) command = ( "create ipsec 'new_tunnel', iketype='ikev1', mode='main', protocol='inet', interface='vip:151.25.19.11', remote_gateway='1.2.3.4', " "authentication_method='pre_shared_key', preshared_key='1234', myid_type='myaddress', peerid_type='peeraddress', lifetime='28800', " - "disable_rekey=False, margintime='', startaction='trap', closeaction='trap', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'") + "disable_rekey=False, margintime='', startaction='trap', closeaction='trap', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'" + ) self.do_module_test(ipsec, command=command) def test_ipsec_create_auto(self): - """ test creation of a new ipsec tunnel """ + """test creation of a new ipsec tunnel""" ipsec = dict( - descr='new_tunnel', interface='lan_100', remote_gateway='1.2.3.4', iketype='auto', - authentication_method='pre_shared_key', preshared_key='1234', mode='main') + descr="new_tunnel", + interface="lan_100", + remote_gateway="1.2.3.4", + iketype="auto", + authentication_method="pre_shared_key", + preshared_key="1234", + mode="main", + ) command = ( "create ipsec 'new_tunnel', iketype='auto', mode='main', protocol='inet', interface='lan_100', remote_gateway='1.2.3.4', " "authentication_method='pre_shared_key', preshared_key='1234', myid_type='myaddress', peerid_type='peeraddress', lifetime='28800', " - "disable_rekey=False, margintime='', startaction='', closeaction='', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'") + "disable_rekey=False, margintime='', startaction='', closeaction='', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'" + ) self.do_module_test(ipsec, command=command) def test_ipsec_delete(self): - """ test deletion of an ipsec """ - ipsec = dict(descr='test_tunnel', state='absent') + """test deletion of an ipsec""" + ipsec = dict(descr="test_tunnel", state="absent") command = "delete ipsec 'test_tunnel'" self.do_module_test(ipsec, delete=True, command=command) def test_ipsec_update_noop(self): - """ test not updating a ipsec """ + """test not updating a ipsec""" ipsec = dict( - descr='test_tunnel', interface='lan_100', remote_gateway='1.2.4.8', iketype='ikev2', - authentication_method='pre_shared_key', preshared_key='1234') + descr="test_tunnel", + interface="lan_100", + remote_gateway="1.2.4.8", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="1234", + ) self.do_module_test(ipsec, changed=False) def test_ipsec_update_ike(self): - """ test updating ike """ + """test updating ike""" ipsec = dict( - descr='test_tunnel', interface='lan_100', remote_gateway='1.2.4.8', iketype='ikev1', - authentication_method='pre_shared_key', preshared_key='1234', mode='main') + descr="test_tunnel", + interface="lan_100", + remote_gateway="1.2.4.8", + iketype="ikev1", + authentication_method="pre_shared_key", + preshared_key="1234", + mode="main", + ) command = "update ipsec 'test_tunnel' set iketype='ikev1', mode='main'" self.do_module_test(ipsec, command=command) def test_ipsec_update_gw(self): - """ test updating gw """ + """test updating gw""" ipsec = dict( - descr='test_tunnel', interface='lan_100', remote_gateway='1.2.3.5', iketype='ikev2', - authentication_method='pre_shared_key', preshared_key='1234') + descr="test_tunnel", + interface="lan_100", + remote_gateway="1.2.3.5", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="1234", + ) command = "update ipsec 'test_tunnel' set remote_gateway='1.2.3.5'" self.do_module_test(ipsec, command=command) def test_ipsec_update_auth(self): - """ test updating auth """ + """test updating auth""" ipsec = dict( - descr='test_tunnel', interface='lan_100', remote_gateway='1.2.4.8', iketype='ikev2', - authentication_method='rsasig', certificate='webConfigurator default (5c00e5f9029df)', certificate_authority='test ca') + descr="test_tunnel", + interface="lan_100", + remote_gateway="1.2.4.8", + iketype="ikev2", + authentication_method="rsasig", + certificate="webConfigurator default (5c00e5f9029df)", + certificate_authority="test ca", + ) command = ( "update ipsec 'test_tunnel' set authentication_method='rsasig', " - "certificate='webConfigurator default (5c00e5f9029df)', certificate_authority='test ca'") + "certificate='webConfigurator default (5c00e5f9029df)', certificate_authority='test ca'" + ) self.do_module_test(ipsec, command=command) def test_ipsec_update_cert(self): - """ test updating certificates """ + """test updating certificates""" ipsec = dict( - descr='test_tunnel2', interface='lan_100', remote_gateway='1.2.3.6', iketype='ikev2', - authentication_method='rsasig', certificate='webConfigurator default copy', certificate_authority='test ca copy') + descr="test_tunnel2", + interface="lan_100", + remote_gateway="1.2.3.6", + iketype="ikev2", + authentication_method="rsasig", + certificate="webConfigurator default copy", + certificate_authority="test ca copy", + ) command = "update ipsec 'test_tunnel2' set certificate='webConfigurator default copy', certificate_authority='test ca copy'" self.do_module_test(ipsec, command=command) def test_ipsec_duplicate_gw(self): - """ test using a duplicate gw """ + """test using a duplicate gw""" ipsec = dict( - descr='new_tunnel', interface='lan_100', remote_gateway='1.2.4.8', iketype='ikev1', - authentication_method='pre_shared_key', preshared_key='1234', mode='main') + descr="new_tunnel", + interface="lan_100", + remote_gateway="1.2.4.8", + iketype="ikev1", + authentication_method="pre_shared_key", + preshared_key="1234", + mode="main", + ) msg = 'The remote gateway "1.2.4.8" is already used by phase1 "test_tunnel".' self.do_module_test(ipsec, msg=msg, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_ipsec_aggregate.py b/tests/unit/plugins/modules/test_pfsense_ipsec_aggregate.py index eb8001df..e8e4697d 100644 --- a/tests/unit/plugins/modules/test_pfsense_ipsec_aggregate.py +++ b/tests/unit/plugins/modules/test_pfsense_ipsec_aggregate.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -10,7 +11,9 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import set_module_args +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + set_module_args, +) from ansible_collections.pfsensible.core.plugins.modules import pfsense_ipsec_aggregate from parameterized import parameterized @@ -18,57 +21,57 @@ class TestPFSenseIpsecAggregateModule(TestPFSenseModule): - module = pfsense_ipsec_aggregate def __init__(self, *args, **kwargs): super(TestPFSenseIpsecAggregateModule, self).__init__(*args, **kwargs) def get_config_file(self): - """ get config file """ + """get config file""" - return 'pfsense_ipsec_aggregate_config.xml' + return "pfsense_ipsec_aggregate_config.xml" def assert_find_ipsec(self, ipsec): - """ test if an ipsec tunnel exist """ + """test if an ipsec tunnel exist""" self.load_xml_result() - parent_tag = self.xml_result.find('ipsec') + parent_tag = self.xml_result.find("ipsec") if parent_tag is None: - self.fail('Unable to find tag ipsec') + self.fail("Unable to find tag ipsec") found = False for ipsec_elt in parent_tag: - if ipsec_elt.tag != 'phase1': + if ipsec_elt.tag != "phase1": continue - if ipsec_elt.find('descr').text == ipsec: + if ipsec_elt.find("descr").text == ipsec: found = True break if not found: - self.fail('Ipsec tunnel not found: ' + ipsec) + self.fail("Ipsec tunnel not found: " + ipsec) def assert_not_find_ipsec(self, ipsec): - """ test if an ipsec tunnel does not exist """ + """test if an ipsec tunnel does not exist""" self.load_xml_result() - parent_tag = self.xml_result.find('ipsec') + parent_tag = self.xml_result.find("ipsec") if parent_tag is None: - self.fail('Unable to find tag ipsec') + self.fail("Unable to find tag ipsec") found = False for ipsec_elt in parent_tag: - if ipsec_elt.tag != 'phase1': + if ipsec_elt.tag != "phase1": continue - if ipsec_elt.find('descr').text == ipsec: + if ipsec_elt.find("descr").text == ipsec: found = True break if found: - self.fail('Ipsec tunnel found: ' + ipsec) + self.fail("Ipsec tunnel found: " + ipsec) def strip_commands(self, commands): - """ remove old or new parameters """ + """remove old or new parameters""" + def strip_command(command): command = command.replace("margintime='', ", "") command = command.replace("disable_rekey=False, ", "") @@ -89,19 +92,37 @@ def strip_command(command): # we just test the output @parameterized.expand([["2.5.2"]]) def test_ipsec_aggregate_ipsecs(self, pfsense_version): - """ test creation of a some tunnels """ + """test creation of a some tunnels""" self.get_version.return_value = pfsense_version args = dict( purge_ipsecs=False, aggregated_ipsecs=[ - dict(descr='t1', interface='wan', remote_gateway='1.3.3.1', iketype='ikev2', authentication_method='pre_shared_key', preshared_key='azerty123'), - dict(descr='t2', interface='wan', remote_gateway='1.3.3.2', iketype='ikev2', authentication_method='pre_shared_key', preshared_key='qwerty123'), - dict(descr='test_tunnel2', state='absent'), dict( - descr='test_tunnel', interface='lan_100', remote_gateway='1.2.4.8', iketype='ikev2', - authentication_method='pre_shared_key', preshared_key='0123456789' + descr="t1", + interface="wan", + remote_gateway="1.3.3.1", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="azerty123", + ), + dict( + descr="t2", + interface="wan", + remote_gateway="1.3.3.2", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="qwerty123", + ), + dict(descr="test_tunnel2", state="absent"), + dict( + descr="test_tunnel", + interface="lan_100", + remote_gateway="1.2.4.8", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="0123456789", ), - ] + ], ) with set_module_args(args): result = self.execute_module(changed=True) @@ -119,25 +140,41 @@ def test_ipsec_aggregate_ipsecs(self, pfsense_version): "mobike='off', startaction='', closeaction='', nat_traversal='on', enable_dpd=True, dpd_delay='10', dpd_maxfail='5'" ) result_ipsecs.append("delete ipsec 'test_tunnel2'") - result_ipsecs.append("update ipsec 'test_tunnel' set preshared_key='0123456789'") + result_ipsecs.append( + "update ipsec 'test_tunnel' set preshared_key='0123456789'" + ) result_ipsecs = self.strip_commands(result_ipsecs) - self.assertEqual(result['result_ipsecs'], result_ipsecs) - self.assert_find_ipsec('t1') - self.assert_find_ipsec('t2') - self.assert_not_find_ipsec('test_tunnel2') - self.assert_find_ipsec('test_tunnel') + self.assertEqual(result["result_ipsecs"], result_ipsecs) + self.assert_find_ipsec("t1") + self.assert_find_ipsec("t2") + self.assert_not_find_ipsec("test_tunnel2") + self.assert_find_ipsec("test_tunnel") @parameterized.expand([["2.5.2"]]) def test_ipsec_aggregate_ipsecs_purge(self, pfsense_version): - """ test creation of a some tunnels with purge """ + """test creation of a some tunnels with purge""" self.get_version.return_value = pfsense_version args = dict( purge_ipsecs=True, aggregated_ipsecs=[ - dict(descr='t1', interface='wan', remote_gateway='1.3.3.1', iketype='ikev2', authentication_method='pre_shared_key', preshared_key='azerty123'), - dict(descr='t2', interface='wan', remote_gateway='1.3.3.2', iketype='ikev2', authentication_method='pre_shared_key', preshared_key='qwerty123'), - ] + dict( + descr="t1", + interface="wan", + remote_gateway="1.3.3.1", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="azerty123", + ), + dict( + descr="t2", + interface="wan", + remote_gateway="1.3.3.2", + iketype="ikev2", + authentication_method="pre_shared_key", + preshared_key="qwerty123", + ), + ], ) with set_module_args(args): result = self.execute_module(changed=True) @@ -158,58 +195,108 @@ def test_ipsec_aggregate_ipsecs_purge(self, pfsense_version): result_ipsecs.append("delete ipsec 'test_tunnel2'") result_ipsecs = self.strip_commands(result_ipsecs) - self.assertEqual(result['result_ipsecs'], result_ipsecs) - self.assert_find_ipsec('t1') - self.assert_find_ipsec('t2') - self.assert_not_find_ipsec('test_tunnel') - self.assert_not_find_ipsec('test_tunnel2') + self.assertEqual(result["result_ipsecs"], result_ipsecs) + self.assert_find_ipsec("t1") + self.assert_find_ipsec("t2") + self.assert_not_find_ipsec("test_tunnel") + self.assert_not_find_ipsec("test_tunnel2") @parameterized.expand([["2.5.2"]]) def test_ipsec_aggregate_proposals(self, pfsense_version): - """ test creation of a some proposals """ + """test creation of a some proposals""" self.get_version.return_value = pfsense_version args = dict( purge_ipsec_proposals=False, aggregated_ipsec_proposals=[ - dict(descr='test_tunnel', encryption='aes', key_length=128, hash='md5', dhgroup=14), - dict(descr='test_tunnel2', encryption='cast128', hash='sha512', dhgroup=14), - dict(descr='test_tunnel', encryption='aes', key_length=128, hash='sha256', dhgroup=14, state='absent'), - dict(descr='test_tunnel2', encryption='blowfish', key_length=256, hash='aesxcbc', dhgroup=14, state='absent'), - ] + dict( + descr="test_tunnel", + encryption="aes", + key_length=128, + hash="md5", + dhgroup=14, + ), + dict( + descr="test_tunnel2", + encryption="cast128", + hash="sha512", + dhgroup=14, + ), + dict( + descr="test_tunnel", + encryption="aes", + key_length=128, + hash="sha256", + dhgroup=14, + state="absent", + ), + dict( + descr="test_tunnel2", + encryption="blowfish", + key_length=256, + hash="aesxcbc", + dhgroup=14, + state="absent", + ), + ], ) with set_module_args(args): self.execute_module(changed=True) result = self.execute_module(changed=True) result_ipsec_proposals = [] - result_ipsec_proposals.append("create ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='md5', dhgroup='14', prf='sha256'") - result_ipsec_proposals.append("create ipsec_proposal 'test_tunnel2', encryption='cast128', hash='sha512', dhgroup='14', prf='sha256'") - result_ipsec_proposals.append("delete ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='sha256', dhgroup='14', prf='sha256'") + result_ipsec_proposals.append( + "create ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='md5', dhgroup='14', prf='sha256'" + ) + result_ipsec_proposals.append( + "create ipsec_proposal 'test_tunnel2', encryption='cast128', hash='sha512', dhgroup='14', prf='sha256'" + ) + result_ipsec_proposals.append( + "delete ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='sha256', dhgroup='14', prf='sha256'" + ) result_ipsec_proposals.append( "delete ipsec_proposal 'test_tunnel2', encryption='blowfish', key_length=256, hash='aesxcbc', dhgroup='14', prf='sha256'" ) result_ipsec_proposals = self.strip_commands(result_ipsec_proposals) - self.assertEqual(result['result_ipsec_proposals'], result_ipsec_proposals) + self.assertEqual(result["result_ipsec_proposals"], result_ipsec_proposals) @parameterized.expand([["2.5.2"]]) def test_ipsec_aggregate_proposals_purge(self, pfsense_version): - """ test creation of a some proposals with purge """ + """test creation of a some proposals with purge""" self.get_version.return_value = pfsense_version args = dict( purge_ipsec_proposals=True, aggregated_ipsec_proposals=[ - dict(descr='test_tunnel', encryption='aes', key_length=128, hash='md5', dhgroup=14), - dict(descr='test_tunnel2', encryption='cast128', hash='sha512', dhgroup=14), - ] + dict( + descr="test_tunnel", + encryption="aes", + key_length=128, + hash="md5", + dhgroup=14, + ), + dict( + descr="test_tunnel2", + encryption="cast128", + hash="sha512", + dhgroup=14, + ), + ], ) with set_module_args(args): self.execute_module(changed=True) result = self.execute_module(changed=True) result_ipsec_proposals = [] - result_ipsec_proposals.append("create ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='md5', dhgroup='14', prf='sha256'") - result_ipsec_proposals.append("create ipsec_proposal 'test_tunnel2', encryption='cast128', hash='sha512', dhgroup='14', prf='sha256'") - result_ipsec_proposals.append("delete ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='sha256', dhgroup='14', prf='sha256'") - result_ipsec_proposals.append("delete ipsec_proposal 'test_tunnel', encryption='aes', key_length=256, hash='sha256', dhgroup='14', prf='sha256'") + result_ipsec_proposals.append( + "create ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='md5', dhgroup='14', prf='sha256'" + ) + result_ipsec_proposals.append( + "create ipsec_proposal 'test_tunnel2', encryption='cast128', hash='sha512', dhgroup='14', prf='sha256'" + ) + result_ipsec_proposals.append( + "delete ipsec_proposal 'test_tunnel', encryption='aes', key_length=128, hash='sha256', dhgroup='14', prf='sha256'" + ) + result_ipsec_proposals.append( + "delete ipsec_proposal 'test_tunnel', encryption='aes', key_length=256, hash='sha256', dhgroup='14', prf='sha256'" + ) result_ipsec_proposals.append( "delete ipsec_proposal 'test_tunnel', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup='14', prf='sha256'" ) @@ -217,8 +304,12 @@ def test_ipsec_aggregate_proposals_purge(self, pfsense_version): "delete ipsec_proposal 'test_tunnel', encryption='blowfish', key_length=256, hash='aesxcbc', dhgroup='14', prf='sha256'" ) - result_ipsec_proposals.append("delete ipsec_proposal 'test_tunnel2', encryption='aes', key_length=128, hash='sha256', dhgroup='14', prf='sha256'") - result_ipsec_proposals.append("delete ipsec_proposal 'test_tunnel2', encryption='aes', key_length=256, hash='sha256', dhgroup='14', prf='sha256'") + result_ipsec_proposals.append( + "delete ipsec_proposal 'test_tunnel2', encryption='aes', key_length=128, hash='sha256', dhgroup='14', prf='sha256'" + ) + result_ipsec_proposals.append( + "delete ipsec_proposal 'test_tunnel2', encryption='aes', key_length=256, hash='sha256', dhgroup='14', prf='sha256'" + ) result_ipsec_proposals.append( "delete ipsec_proposal 'test_tunnel2', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup='14', prf='sha256'" ) @@ -227,21 +318,47 @@ def test_ipsec_aggregate_proposals_purge(self, pfsense_version): ) result_ipsec_proposals = self.strip_commands(result_ipsec_proposals) - self.assertEqual(result['result_ipsec_proposals'], result_ipsec_proposals) + self.assertEqual(result["result_ipsec_proposals"], result_ipsec_proposals) def test_ipsec_aggregate_p2s(self): - """ test creation of a some p2s """ + """test creation of a some p2s""" args = dict( purge_ipsec_p2s=False, aggregated_ipsec_p2s=[ - dict(descr='p2_1', p1_descr='test_tunnel', mode='tunnel', local='1.2.3.4/24', remote='10.20.30.40/24', aes=True, aes_len='auto', sha256=True), - dict(descr='p2_2', p1_descr='test_tunnel', mode='tunnel', local='1.2.3.4/24', remote='10.20.30.50/24', aes=True, aes_len='auto', sha256=True), dict( - descr='one_p2', p1_descr='test_tunnel', mode='tunnel', local='lan', remote='10.20.30.60/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True' + descr="p2_1", + p1_descr="test_tunnel", + mode="tunnel", + local="1.2.3.4/24", + remote="10.20.30.40/24", + aes=True, + aes_len="auto", + sha256=True, + ), + dict( + descr="p2_2", + p1_descr="test_tunnel", + mode="tunnel", + local="1.2.3.4/24", + remote="10.20.30.50/24", + aes=True, + aes_len="auto", + sha256=True, + ), + dict( + descr="one_p2", + p1_descr="test_tunnel", + mode="tunnel", + local="lan", + remote="10.20.30.60/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", ), - dict(descr='another_p2', p1_descr='test_tunnel', state='absent') - ] + dict(descr="another_p2", p1_descr="test_tunnel", state="absent"), + ], ) with set_module_args(args): result = self.execute_module(changed=True) @@ -254,19 +371,39 @@ def test_ipsec_aggregate_p2s(self): "create ipsec_p2 'p2_2' on 'test_tunnel', disabled=False, mode='tunnel', local='1.2.3.4/24', remote='10.20.30.50/24', " "aes=True, aes_len='auto', sha256=True, pfsgroup='14', lifetime=3600" ) - result_ipsec_p2s.append("update ipsec_p2 'one_p2' on 'test_tunnel' set remote='10.20.30.60/24'") + result_ipsec_p2s.append( + "update ipsec_p2 'one_p2' on 'test_tunnel' set remote='10.20.30.60/24'" + ) result_ipsec_p2s.append("delete ipsec_p2 'another_p2' on 'test_tunnel'") - self.assertEqual(result['result_ipsec_p2s'], result_ipsec_p2s) + self.assertEqual(result["result_ipsec_p2s"], result_ipsec_p2s) def test_ipsec_aggregate_p2s_purge(self): - """ test creation of a some p2s with purge """ + """test creation of a some p2s with purge""" args = dict( purge_ipsec_p2s=True, aggregated_ipsec_p2s=[ - dict(descr='p2_1', p1_descr='test_tunnel', mode='tunnel', local='1.2.3.4/24', remote='10.20.30.40/24', aes=True, aes_len='auto', sha256=True), - dict(descr='p2_2', p1_descr='test_tunnel', mode='tunnel', local='1.2.3.4/24', remote='10.20.30.50/24', aes=True, aes_len='auto', sha256=True), - ] + dict( + descr="p2_1", + p1_descr="test_tunnel", + mode="tunnel", + local="1.2.3.4/24", + remote="10.20.30.40/24", + aes=True, + aes_len="auto", + sha256=True, + ), + dict( + descr="p2_2", + p1_descr="test_tunnel", + mode="tunnel", + local="1.2.3.4/24", + remote="10.20.30.50/24", + aes=True, + aes_len="auto", + sha256=True, + ), + ], ) with set_module_args(args): result = self.execute_module(changed=True) @@ -284,4 +421,4 @@ def test_ipsec_aggregate_p2s_purge(self): result_ipsec_p2s.append("delete ipsec_p2 'third_p2' on 'test_tunnel'") result_ipsec_p2s.append("delete ipsec_p2 'nat_p2' on 'test_tunnel'") - self.assertEqual(result['result_ipsec_p2s'], result_ipsec_p2s) + self.assertEqual(result["result_ipsec_p2s"], result_ipsec_p2s) diff --git a/tests/unit/plugins/modules/test_pfsense_ipsec_p2.py b/tests/unit/plugins/modules/test_pfsense_ipsec_p2.py index 41725b02..149bb8ed 100644 --- a/tests/unit/plugins/modules/test_pfsense_ipsec_p2.py +++ b/tests/unit/plugins/modules/test_pfsense_ipsec_p2.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,187 +12,199 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_ipsec_p2 -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import PFSenseIpsecP2Module +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_p2 import ( + PFSenseIpsecP2Module, +) from .pfsense_module import TestPFSenseModule class TestPFSenseIpsecP2Module(TestPFSenseModule): - module = pfsense_ipsec_p2 def __init__(self, *args, **kwargs): super(TestPFSenseIpsecP2Module, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_ipsec_p2_config.xml' + self.config_file = "pfsense_ipsec_p2_config.xml" self.pfmodule = PFSenseIpsecP2Module ############## # tests utils # def get_phase1_elt(self, descr, absent=False): - """ get phase1 """ + """get phase1""" elt_filter = {} - elt_filter['descr'] = descr - return self.assert_has_xml_tag('ipsec', elt_filter, absent=absent) + elt_filter["descr"] = descr + return self.assert_has_xml_tag("ipsec", elt_filter, absent=absent) def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated phase2 xml definition """ - phase1_elt = self.get_phase1_elt(obj['p1_descr']) + """get the generated phase2 xml definition""" + phase1_elt = self.get_phase1_elt(obj["p1_descr"]) elt_filter = {} - elt_filter['descr'] = obj['descr'] - elt_filter['ikeid'] = phase1_elt.find('ikeid').text - return self.assert_has_xml_tag('ipsec', elt_filter, absent=absent) + elt_filter["descr"] = obj["descr"] + elt_filter["ikeid"] = phase1_elt.find("ikeid").text + return self.assert_has_xml_tag("ipsec", elt_filter, absent=absent) @staticmethod def get_enc_elt(phase2_elt, enc_name): - """ get encryption """ + """get encryption""" for elt in phase2_elt: - if elt.tag != 'encryption-algorithm-option': + if elt.tag != "encryption-algorithm-option": continue - if elt.find('name').text == enc_name: + if elt.find("name").text == enc_name: return elt return None def check_enc(self, phase2, phase2_elt, enc_name, param_name): - """ check encryption """ + """check encryption""" enc_elt = self.get_enc_elt(phase2_elt, enc_name) if phase2.get(param_name): if enc_elt is None: - self.fail('Encryption named {0} not found'.format(enc_name)) - if phase2.get(param_name + '_len') is not None: - keylen_elt = enc_elt.find('keylen') + self.fail("Encryption named {0} not found".format(enc_name)) + if phase2.get(param_name + "_len") is not None: + keylen_elt = enc_elt.find("keylen") if keylen_elt is None: - self.fail('Key length not found for encryption named {0}'.format(enc_name)) - self.assertEqual(keylen_elt.text, phase2[param_name + '_len']) + self.fail( + "Key length not found for encryption named {0}".format(enc_name) + ) + self.assertEqual(keylen_elt.text, phase2[param_name + "_len"]) else: if enc_elt is not None: - self.fail('Encryption named {0} found'.format(enc_name)) + self.fail("Encryption named {0} found".format(enc_name)) @staticmethod def get_hash_elt(phase2_elt, hash_name): - """ get hash """ + """get hash""" for elt in phase2_elt: - if elt.tag != 'hash-algorithm-option': + if elt.tag != "hash-algorithm-option": continue if elt.text == hash_name: return elt return None def check_hash(self, phase2, phase2_elt, hash_name, param_name): - """ check hash """ + """check hash""" hash_elt = self.get_hash_elt(phase2_elt, hash_name) if phase2.get(param_name): if hash_elt is None: - self.fail('Hash algorithm named {0} not found'.format(hash_name)) + self.fail("Hash algorithm named {0} not found".format(hash_name)) else: if hash_elt is not None: - self.fail('Hash algorithm named {0} found'.format(hash_name)) + self.fail("Hash algorithm named {0} found".format(hash_name)) def param_to_address(self, address): - """ hardcoded addresses """ + """hardcoded addresses""" ret = dict() - if address in ['1.2.3.1', '1.2.3.2']: - ret['type'] = 'address' - ret['address'] = address - ret['type'] = 'address' - ret['address'] = address - elif address == '1.2.3.4/24': - ret['type'] = 'network' - ret['address'] = '1.2.3.4' - ret['netbits'] = '24' - elif address == '10.20.30.40/24': - ret['type'] = 'network' - ret['address'] = '10.20.30.40' - ret['netbits'] = '24' - elif address == '10.20.30.50/24': - ret['type'] = 'network' - ret['address'] = '10.20.30.50' - ret['netbits'] = '24' - elif address in ['lan_100', 'lan']: - ret['type'] = self.unalias_interface(address) + if address in ["1.2.3.1", "1.2.3.2"]: + ret["type"] = "address" + ret["address"] = address + ret["type"] = "address" + ret["address"] = address + elif address == "1.2.3.4/24": + ret["type"] = "network" + ret["address"] = "1.2.3.4" + ret["netbits"] = "24" + elif address == "10.20.30.40/24": + ret["type"] = "network" + ret["address"] = "10.20.30.40" + ret["netbits"] = "24" + elif address == "10.20.30.50/24": + ret["type"] = "network" + ret["address"] = "10.20.30.50" + ret["netbits"] = "24" + elif address in ["lan_100", "lan"]: + ret["type"] = self.unalias_interface(address) else: - self.fail('Please add address {0} to param_to_address'.format(address)) + self.fail("Please add address {0} to param_to_address".format(address)) return ret def check_address(self, phase2, phase2_elt, elt_name, param_name): - """ check address """ + """check address""" if phase2.get(param_name) is None: if phase2_elt.find(elt_name) is not None: - self.fail('Address type {0} found'.format(elt_name)) + self.fail("Address type {0} found".format(elt_name)) else: addr_elt = phase2_elt.find(elt_name) if addr_elt is None: - self.fail('Address type {0} not found'.format(elt_name)) + self.fail("Address type {0} not found".format(elt_name)) address = self.param_to_address(phase2[param_name]) for param in address.keys(): elt = addr_elt.find(param) if elt is None: - self.fail('Address param {0} not found'.format(param)) + self.fail("Address param {0} not found".format(param)) self.assertEqual(elt.text, address[param]) params = address.keys() for elt in addr_elt: if elt.tag not in params: - self.fail('Address param{0} found'.format(elt.tag)) + self.fail("Address param{0} found".format(elt.tag)) def check_target_elt(self, obj, target_elt): - """ test the xml definition of phase2 elt """ + """test the xml definition of phase2 elt""" # bools - if obj.get('disabled'): - self.assert_xml_elt_is_none_or_empty(target_elt, 'disabled') + if obj.get("disabled"): + self.assert_xml_elt_is_none_or_empty(target_elt, "disabled") else: - self.assert_not_find_xml_elt(target_elt, 'disabled') + self.assert_not_find_xml_elt(target_elt, "disabled") - self.assert_xml_elt_equal(target_elt, 'mode', obj['mode']) - if obj.get('procotol') is not None: - self.assert_xml_elt_equal(target_elt, 'protocol', obj['protocol']) + self.assert_xml_elt_equal(target_elt, "mode", obj["mode"]) + if obj.get("procotol") is not None: + self.assert_xml_elt_equal(target_elt, "protocol", obj["protocol"]) else: - self.assert_xml_elt_equal(target_elt, 'protocol', 'esp') - if obj.get('pfsgroup') is not None: - self.assert_xml_elt_equal(target_elt, 'pfsgroup', obj['pfsgroup']) + self.assert_xml_elt_equal(target_elt, "protocol", "esp") + if obj.get("pfsgroup") is not None: + self.assert_xml_elt_equal(target_elt, "pfsgroup", obj["pfsgroup"]) else: - self.assert_xml_elt_equal(target_elt, 'pfsgroup', '14') + self.assert_xml_elt_equal(target_elt, "pfsgroup", "14") - if obj.get('lifetime') is not None: - if obj['lifetime'] == 0: - self.assert_xml_elt_is_none_or_empty(target_elt, 'lifetime') + if obj.get("lifetime") is not None: + if obj["lifetime"] == 0: + self.assert_xml_elt_is_none_or_empty(target_elt, "lifetime") else: - self.assert_xml_elt_equal(target_elt, 'lifetime', str(obj['lifetime'])) + self.assert_xml_elt_equal(target_elt, "lifetime", str(obj["lifetime"])) else: - self.assert_xml_elt_equal(target_elt, 'lifetime', '3600') + self.assert_xml_elt_equal(target_elt, "lifetime", "3600") - if obj.get('pinghost') is not None: - self.assert_xml_elt_equal(target_elt, 'pinghost', str(obj['pinghost'])) + if obj.get("pinghost") is not None: + self.assert_xml_elt_equal(target_elt, "pinghost", str(obj["pinghost"])) else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'pinghost') + self.assert_xml_elt_is_none_or_empty(target_elt, "pinghost") # encryptions - self.check_enc(obj, target_elt, 'aes', 'aes') - self.check_enc(obj, target_elt, 'aes128gcm', 'aes128gcm') - self.check_enc(obj, target_elt, 'aes192gcm', 'aes192gcm') - self.check_enc(obj, target_elt, 'aes256gcm', 'aes256gcm') - self.check_enc(obj, target_elt, 'blowfish', 'blowfish') - self.check_enc(obj, target_elt, '3des', 'des') - self.check_enc(obj, target_elt, 'cast128', 'cast128') + self.check_enc(obj, target_elt, "aes", "aes") + self.check_enc(obj, target_elt, "aes128gcm", "aes128gcm") + self.check_enc(obj, target_elt, "aes192gcm", "aes192gcm") + self.check_enc(obj, target_elt, "aes256gcm", "aes256gcm") + self.check_enc(obj, target_elt, "blowfish", "blowfish") + self.check_enc(obj, target_elt, "3des", "des") + self.check_enc(obj, target_elt, "cast128", "cast128") # hashes - self.check_hash(obj, target_elt, 'hmac_sha1', 'sha1') - self.check_hash(obj, target_elt, 'hmac_sha256', 'sha256') - self.check_hash(obj, target_elt, 'hmac_sha384', 'sha384') - self.check_hash(obj, target_elt, 'hmac_sha512', 'sha512') - self.check_hash(obj, target_elt, 'aesxcbc', 'aesxcbc') + self.check_hash(obj, target_elt, "hmac_sha1", "sha1") + self.check_hash(obj, target_elt, "hmac_sha256", "sha256") + self.check_hash(obj, target_elt, "hmac_sha384", "sha384") + self.check_hash(obj, target_elt, "hmac_sha512", "sha512") + self.check_hash(obj, target_elt, "aesxcbc", "aesxcbc") - self.check_address(obj, target_elt, 'localid', 'local') - self.check_address(obj, target_elt, 'remoteid', 'remote') - self.check_address(obj, target_elt, 'natlocalid', 'nat') + self.check_address(obj, target_elt, "localid", "local") + self.check_address(obj, target_elt, "remoteid", "remote") + self.check_address(obj, target_elt, "natlocalid", "nat") ############## # tests # def test_phase2_create_vti(self): - """ test creation of a new phase2 in vti mode """ - phase2 = dict(p1_descr='test_tunnel', descr='test_p2', mode='vti', local='1.2.3.1', remote='1.2.3.2', aes='True', aes_len='auto', sha256='True') + """test creation of a new phase2 in vti mode""" + phase2 = dict( + p1_descr="test_tunnel", + descr="test_p2", + mode="vti", + local="1.2.3.1", + remote="1.2.3.2", + aes="True", + aes_len="auto", + sha256="True", + ) command = ( "create ipsec_p2 'test_p2' on 'test_tunnel', disabled=False, mode='vti', local='1.2.3.1', remote='1.2.3.2', " "aes=True, aes_len='auto', sha256=True, pfsgroup='14', lifetime=3600" @@ -199,8 +212,17 @@ def test_phase2_create_vti(self): self.do_module_test(phase2, command=command) def test_phase2_create_tunnel(self): - """ test creation of a new phase2 in tunnel mode """ - phase2 = dict(p1_descr='test_tunnel', descr='test_p2', mode='tunnel', local='lan_100', remote='1.2.3.4/24', aes='True', aes_len='auto', sha256='True') + """test creation of a new phase2 in tunnel mode""" + phase2 = dict( + p1_descr="test_tunnel", + descr="test_p2", + mode="tunnel", + local="lan_100", + remote="1.2.3.4/24", + aes="True", + aes_len="auto", + sha256="True", + ) command = ( "create ipsec_p2 'test_p2' on 'test_tunnel', disabled=False, mode='tunnel', local='lan_100', remote='1.2.3.4/24', " "aes=True, aes_len='auto', sha256=True, pfsgroup='14', lifetime=3600" @@ -208,170 +230,360 @@ def test_phase2_create_tunnel(self): self.do_module_test(phase2, command=command) def test_phase2_delete(self): - """ test deletion of a phase2 """ - phase2 = dict(p1_descr='test_tunnel', descr='one_p2', state='absent') + """test deletion of a phase2""" + phase2 = dict(p1_descr="test_tunnel", descr="one_p2", state="absent") command = "delete ipsec_p2 'one_p2' on 'test_tunnel'" self.do_module_test(phase2, delete=True, command=command) def test_phase2_update_noop(self): - """ test not updating a phase2 """ + """test not updating a phase2""" phase2 = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='10.20.30.40/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="10.20.30.40/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) self.do_module_test(phase2, changed=False) def test_phase2_update_aes_len(self): - """ test update aes """ + """test update aes""" phase2 = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='10.20.30.40/24', - aes='True', aes_len='auto', aes128gcm=True, aes128gcm_len='128', sha256='True') + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="10.20.30.40/24", + aes="True", + aes_len="auto", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) command = "update ipsec_p2 'one_p2' on 'test_tunnel' set aes_len='auto'" self.do_module_test(phase2, command=command) def test_phase2_update_disable_aes(self): - """ test removing aes """ + """test removing aes""" phase2 = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='10.20.30.40/24', - aes128gcm=True, aes128gcm_len='128', sha256='True') - command = "update ipsec_p2 'one_p2' on 'test_tunnel' set aes=False, aes_len=none" + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="10.20.30.40/24", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + command = ( + "update ipsec_p2 'one_p2' on 'test_tunnel' set aes=False, aes_len=none" + ) self.do_module_test(phase2, command=command) def test_phase2_update_set_3des(self): - """ test enabling 3des """ + """test enabling 3des""" phase2 = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='10.20.30.40/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', des=True, sha256='True') + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="10.20.30.40/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + des=True, + sha256="True", + ) command = "update ipsec_p2 'one_p2' on 'test_tunnel' set des=True" self.do_module_test(phase2, command=command) def test_phase2_update_remove_3des(self): - """ test disabling 3des """ + """test disabling 3des""" phase2 = dict( - p1_descr='test_tunnel', descr='another_p2', mode='tunnel', local='lan', remote='10.20.30.50/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', des=False, sha256='True') + p1_descr="test_tunnel", + descr="another_p2", + mode="tunnel", + local="lan", + remote="10.20.30.50/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + des=False, + sha256="True", + ) command = "update ipsec_p2 'another_p2' on 'test_tunnel' set des=False" self.do_module_test(phase2, command=command) def test_phase2_update_remove_sha256(self): - """ test disabling sha256 """ + """test disabling sha256""" phase2 = dict( - p1_descr='test_tunnel', descr='another_p2', mode='tunnel', local='lan', remote='10.20.30.50/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', des=True, sha512='True') + p1_descr="test_tunnel", + descr="another_p2", + mode="tunnel", + local="lan", + remote="10.20.30.50/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + des=True, + sha512="True", + ) command = "update ipsec_p2 'another_p2' on 'test_tunnel' set sha256=False, sha512=True" self.do_module_test(phase2, command=command) def test_phase2_update_change_address(self): - """ test changing address """ + """test changing address""" phase2 = dict( - p1_descr='test_tunnel', descr='third_p2', mode='tunnel', local='lan_100', remote='10.20.30.50/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', des=True, sha256='True') + p1_descr="test_tunnel", + descr="third_p2", + mode="tunnel", + local="lan_100", + remote="10.20.30.50/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + des=True, + sha256="True", + ) command = "update ipsec_p2 'third_p2' on 'test_tunnel' set local='lan_100'" self.do_module_test(phase2, command=command) def test_phase2_update_set_nat(self): - """ test setting nat """ + """test setting nat""" phase2 = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='10.20.30.40/24', nat='1.2.3.4/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="10.20.30.40/24", + nat="1.2.3.4/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) command = "update ipsec_p2 'one_p2' on 'test_tunnel' set nat='1.2.3.4/24'" self.do_module_test(phase2, command=command) def test_phase2_update_remove_nat(self): - """ test removing nat """ + """test removing nat""" phase2 = dict( - p1_descr='test_tunnel', descr='nat_p2', mode='tunnel', local='lan', remote='1.2.3.4/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') + p1_descr="test_tunnel", + descr="nat_p2", + mode="tunnel", + local="lan", + remote="1.2.3.4/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) command = "update ipsec_p2 'nat_p2' on 'test_tunnel' set nat=none" self.do_module_test(phase2, command=command) def test_phase2_inexistent_tunnel(self): - """ test error with inexistent tunnel """ + """test error with inexistent tunnel""" ipsec = dict( - p1_descr='inexistent_tunnel', descr='nat_p2', mode='tunnel', local='lan', remote='1.2.3.4/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'No ipsec tunnel named inexistent_tunnel' + p1_descr="inexistent_tunnel", + descr="nat_p2", + mode="tunnel", + local="lan", + remote="1.2.3.4/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = "No ipsec tunnel named inexistent_tunnel" self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_no_encryption(self): - """ test error with no encryption """ + """test error with no encryption""" ipsec = dict( - p1_descr='test_tunnel', descr='nat_p2', mode='tunnel', local='lan', remote='1.2.3.4/24', sha256='True') - msg = 'At least one encryption algorithm must be selected.' + p1_descr="test_tunnel", + descr="nat_p2", + mode="tunnel", + local="lan", + remote="1.2.3.4/24", + sha256="True", + ) + msg = "At least one encryption algorithm must be selected." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_no_hash(self): - """ test error with no hash """ + """test error with no hash""" ipsec = dict( - p1_descr='test_tunnel', descr='nat_p2', mode='tunnel', local='lan', remote='1.2.3.4/24', cast128='True') - msg = 'At least one hashing algorithm needs to be selected.' + p1_descr="test_tunnel", + descr="nat_p2", + mode="tunnel", + local="lan", + remote="1.2.3.4/24", + cast128="True", + ) + msg = "At least one hashing algorithm needs to be selected." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_vti_lan(self): - """ test error on vti address """ + """test error on vti address""" ipsec = dict( - p1_descr='test_tunnel', descr='nat_p2', mode='vti', local='lan', remote='1.2.3.4', cast128='True', sha256='True') - msg = 'VTI requires a valid local network or IP address for its endpoint address.' + p1_descr="test_tunnel", + descr="nat_p2", + mode="vti", + local="lan", + remote="1.2.3.4", + cast128="True", + sha256="True", + ) + msg = ( + "VTI requires a valid local network or IP address for its endpoint address." + ) self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_vti_lan2(self): - """ test error on vti address """ + """test error on vti address""" ipsec = dict( - p1_descr='test_tunnel', descr='nat_p2', mode='vti', local='1.2.3.4', remote='lan', cast128='True', sha256='True') - msg = 'VTI requires a valid remote IP address for its endpoint address.' + p1_descr="test_tunnel", + descr="nat_p2", + mode="vti", + local="1.2.3.4", + remote="lan", + cast128="True", + sha256="True", + ) + msg = "VTI requires a valid remote IP address for its endpoint address." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_tunnel6_remote(self): - """ test error on tunnel6 address """ + """test error on tunnel6 address""" ipsec = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel6', local='lan', remote='10.20.30.40/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'A valid IPv6 address or network must be specified in remote with tunnel6.' + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel6", + local="lan", + remote="10.20.30.40/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = ( + "A valid IPv6 address or network must be specified in remote with tunnel6." + ) self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_tunnel6_remote2(self): - """ test error on tunnel6 address """ + """test error on tunnel6 address""" ipsec = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel6', local='lan', remote='1.2.3.4', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'A valid IPv6 address or network must be specified in remote with tunnel6.' + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel6", + local="lan", + remote="1.2.3.4", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = ( + "A valid IPv6 address or network must be specified in remote with tunnel6." + ) self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_tunnel6_local(self): - """ test error on tunnel6 address """ + """test error on tunnel6 address""" ipsec = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel6', local='1.2.3.4/24', remote='10.20.30.40/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'A valid IPv6 address or network must be specified in local with tunnel6.' + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel6", + local="1.2.3.4/24", + remote="10.20.30.40/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = "A valid IPv6 address or network must be specified in local with tunnel6." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_tunnel_remote(self): - """ test error on tunnel address """ + """test error on tunnel address""" ipsec = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='fd69:81a5:a5:7396:0:0:0:0', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'A valid IPv4 address or network must be specified in remote with tunnel.' + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="fd69:81a5:a5:7396:0:0:0:0", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = "A valid IPv4 address or network must be specified in remote with tunnel." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_tunnel_remote2(self): - """ test error on tunnel address """ + """test error on tunnel address""" ipsec = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='lan', remote='fd69:81a5:a5:7396:0:0:0:0/64', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'A valid IPv4 address or network must be specified in remote with tunnel.' + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="lan", + remote="fd69:81a5:a5:7396:0:0:0:0/64", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = "A valid IPv4 address or network must be specified in remote with tunnel." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_tunnel_local(self): - """ test error on tunnel address """ + """test error on tunnel address""" ipsec = dict( - p1_descr='test_tunnel', descr='one_p2', mode='tunnel', local='fd69:81a5:a5:7396:0:0:0:0', remote='10.20.30.40/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'A valid IPv4 address or network must be specified in local with tunnel.' + p1_descr="test_tunnel", + descr="one_p2", + mode="tunnel", + local="fd69:81a5:a5:7396:0:0:0:0", + remote="10.20.30.40/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = "A valid IPv4 address or network must be specified in local with tunnel." self.do_module_test(ipsec, msg=msg, failed=True) def test_phase2_duplicate(self): - """ test error duplicate local/remote definition """ + """test error duplicate local/remote definition""" phase2 = dict( - p1_descr='test_tunnel', descr='duplicate_p2', mode='tunnel', local='lan', remote='10.20.30.40/24', - aes='True', aes_len='128', aes128gcm=True, aes128gcm_len='128', sha256='True') - msg = 'Phase2 with this Local/Remote networks combination is already defined for this Phase1.' + p1_descr="test_tunnel", + descr="duplicate_p2", + mode="tunnel", + local="lan", + remote="10.20.30.40/24", + aes="True", + aes_len="128", + aes128gcm=True, + aes128gcm_len="128", + sha256="True", + ) + msg = "Phase2 with this Local/Remote networks combination is already defined for this Phase1." self.do_module_test(phase2, msg=msg, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_ipsec_proposal.py b/tests/unit/plugins/modules/test_pfsense_ipsec_proposal.py index ca183654..87535a76 100644 --- a/tests/unit/plugins/modules/test_pfsense_ipsec_proposal.py +++ b/tests/unit/plugins/modules/test_pfsense_ipsec_proposal.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,13 +12,14 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_ipsec_proposal -from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import PFSenseIpsecProposalModule +from ansible_collections.pfsensible.core.plugins.module_utils.ipsec_proposal import ( + PFSenseIpsecProposalModule, +) from .pfsense_module import TestPFSenseModule from parameterized import parameterized class TestPFSenseIpsecProposalModule(TestPFSenseModule): - module = pfsense_ipsec_proposal def __init__(self, *args, **kwargs): @@ -25,64 +27,69 @@ def __init__(self, *args, **kwargs): self.pfmodule = PFSenseIpsecProposalModule def get_config_file(self): - """ get config file """ + """get config file""" if self.get_version.return_value.startswith("2.4."): - return '2.4/pfsense_ipsec_proposal_config.xml' + return "2.4/pfsense_ipsec_proposal_config.xml" - return 'pfsense_ipsec_proposal_config.xml' + return "pfsense_ipsec_proposal_config.xml" ############## # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated proposal xml definition """ + """get the generated proposal xml definition""" elt_filter = {} - elt_filter['descr'] = obj['descr'] + elt_filter["descr"] = obj["descr"] - ipsec_elt = self.assert_has_xml_tag('ipsec', elt_filter) + ipsec_elt = self.assert_has_xml_tag("ipsec", elt_filter) if ipsec_elt is None: return None - encryption_elt = ipsec_elt.find('encryption') + encryption_elt = ipsec_elt.find("encryption") if encryption_elt is None: return None for item_elt in encryption_elt: - elt = item_elt.find('dhgroup') - if elt is None or elt.text != str(obj['dhgroup']): + elt = item_elt.find("dhgroup") + if elt is None or elt.text != str(obj["dhgroup"]): continue - elt = item_elt.find('hash-algorithm') - if elt is None or elt.text != obj['hash']: + elt = item_elt.find("hash-algorithm") + if elt is None or elt.text != obj["hash"]: continue if not self.get_version.return_value.startswith("2.4."): - elt = item_elt.find('prf-algorithm') - if elt is None or 'prf' not in obj and elt.text != 'sha256' and elt.text != obj['prf']: + elt = item_elt.find("prf-algorithm") + if ( + elt is None + or "prf" not in obj + and elt.text != "sha256" + and elt.text != obj["prf"] + ): continue - encalg_elt = item_elt.find('encryption-algorithm') + encalg_elt = item_elt.find("encryption-algorithm") if encalg_elt is None: continue - elt = encalg_elt.find('name') - if elt is None or elt.text != obj['encryption']: + elt = encalg_elt.find("name") + if elt is None or elt.text != obj["encryption"]: continue - elt = encalg_elt.find('keylen') - if (elt is None or elt.text == '') and obj.get('key_length') is None: + elt = encalg_elt.find("keylen") + if (elt is None or elt.text == "") and obj.get("key_length") is None: return item_elt - if elt is not None and elt.text == str(obj.get('key_length')): + if elt is not None and elt.text == str(obj.get("key_length")): return item_elt return None def check_target_elt(self, obj, target_elt): - """ test the xml definition of proposal elt """ + """test the xml definition of proposal elt""" if target_elt is None: - self.fail('Unable to find proposal on ' + obj['descr']) + self.fail("Unable to find proposal on " + obj["descr"]) def strip_commands(self, commands): - """ remove old or new parameters """ + """remove old or new parameters""" if self.get_version.return_value.startswith("2.4."): commands = commands.replace(", prf='sha256'", "") return commands @@ -92,49 +99,88 @@ def strip_commands(self, commands): # @parameterized.expand([["2.4.4"], ["2.5.0"], ["2.5.2"]]) def test_ipsec_proposal_create(self, pfsense_version): - """ test creation of a new proposal """ + """test creation of a new proposal""" self.get_version.return_value = pfsense_version - proposal = dict(descr='test_tunnel', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup=21) + proposal = dict( + descr="test_tunnel", + encryption="aes128gcm", + key_length=128, + hash="sha256", + dhgroup=21, + ) command = "create ipsec_proposal 'test_tunnel', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup='21', prf='sha256'" self.do_module_test(proposal, command=command) @parameterized.expand([["2.4.4"], ["2.5.0"], ["2.5.2"]]) def test_ipsec_proposal_create_nokeylen(self, pfsense_version): - """ test creation of a new proposal """ + """test creation of a new proposal""" self.get_version.return_value = pfsense_version - proposal = dict(descr='test_tunnel2', encryption='cast128', hash='sha256', dhgroup=21) + proposal = dict( + descr="test_tunnel2", encryption="cast128", hash="sha256", dhgroup=21 + ) command = "create ipsec_proposal 'test_tunnel2', encryption='cast128', hash='sha256', dhgroup='21', prf='sha256'" self.do_module_test(proposal, command=command) @parameterized.expand([["2.4.4"], ["2.5.0"], ["2.5.2"]]) def test_ipsec_proposal_delete(self, pfsense_version): - """ test deletion of an ipsec proposal """ + """test deletion of an ipsec proposal""" self.get_version.return_value = pfsense_version - proposal = dict(descr='test_tunnel', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup=14, state='absent') + proposal = dict( + descr="test_tunnel", + encryption="aes128gcm", + key_length=128, + hash="sha256", + dhgroup=14, + state="absent", + ) command = "delete ipsec_proposal 'test_tunnel', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup='14', prf='sha256'" self.do_module_test(proposal, delete=True, command=command) @parameterized.expand([["2.4.4"], ["2.5.0"], ["2.5.2"]]) def test_ipsec_proposal_update_noop(self, pfsense_version): - """ test not updating a ipsec proposal """ + """test not updating a ipsec proposal""" self.get_version.return_value = pfsense_version - proposal = dict(descr='test_tunnel', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup=14) + proposal = dict( + descr="test_tunnel", + encryption="aes128gcm", + key_length=128, + hash="sha256", + dhgroup=14, + ) self.do_module_test(proposal, changed=False) def test_ipsec_proposal_wrong_keylen(self): - """ test using a wrong key_length """ - proposal = dict(descr='test_tunnel', encryption='aes128gcm', key_length=256, hash='sha256', dhgroup=14) - msg = 'key_length for encryption aes128gcm must be one of: 64, 96, 128.' + """test using a wrong key_length""" + proposal = dict( + descr="test_tunnel", + encryption="aes128gcm", + key_length=256, + hash="sha256", + dhgroup=14, + ) + msg = "key_length for encryption aes128gcm must be one of: 64, 96, 128." self.do_module_test(proposal, msg=msg, failed=True) def test_ipsec_proposal_wrong_tunnel(self): - """ test using a wrong tunnel """ - proposal = dict(descr='test_tunnel3', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup=14) - msg = 'No ipsec tunnel named test_tunnel3' + """test using a wrong tunnel""" + proposal = dict( + descr="test_tunnel3", + encryption="aes128gcm", + key_length=128, + hash="sha256", + dhgroup=14, + ) + msg = "No ipsec tunnel named test_tunnel3" self.do_module_test(proposal, msg=msg, failed=True) def test_ipsec_proposal_wrong_encryption(self): - """ test using a wrong encryption """ - proposal = dict(descr='test_tunnel2', encryption='aes128gcm', key_length=128, hash='sha256', dhgroup=14) - msg = 'Encryption Algorithm AES-GCM can only be used with IKEv2' + """test using a wrong encryption""" + proposal = dict( + descr="test_tunnel2", + encryption="aes128gcm", + key_length=128, + hash="sha256", + dhgroup=14, + ) + msg = "Encryption Algorithm AES-GCM can only be used with IKEv2" self.do_module_test(proposal, msg=msg, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_log_settings.py b/tests/unit/plugins/modules/test_pfsense_log_settings.py index bfc11823..e17155ad 100644 --- a/tests/unit/plugins/modules/test_pfsense_log_settings.py +++ b/tests/unit/plugins/modules/test_pfsense_log_settings.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -16,30 +17,30 @@ class TestPFSenseLogSettingsModule(TestPFSenseModule): - module = pfsense_log_settings def __init__(self, *args, **kwargs): super(TestPFSenseLogSettingsModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_syslog_config.xml' + self.config_file = "pfsense_syslog_config.xml" self.pfmodule = pfsense_log_settings.PFSenseLogSettingsModule self.defaults = { - 'filterdescriptions': 1, - 'reverse': True, - 'nentries': 50, - 'sourceip': None, - 'ipproto': 'ipv4', + "filterdescriptions": 1, + "reverse": True, + "nentries": 50, + "sourceip": None, + "ipproto": "ipv4", } ############## # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - return self.assert_find_xml_elt(self.xml_result, 'syslog') + """get the generated xml definition""" + return self.assert_find_xml_elt(self.xml_result, "syslog") def check_target_elt(self, obj, target_elt): - """ test the xml definition of target elt """ + """test the xml definition of target elt""" + def check_param(param, xml_field=None): if obj is not None: if xml_field is None: @@ -49,20 +50,26 @@ def check_param(param, xml_field=None): # Special handling for sourceip # Given as ip or descr but set as internal interface id interface_map = { - '192.168.240.137': 'wan', - 'wan': 'wan', - '192.168.1.242': 'lan', - '10.255.2.254': '_vip5c0a4b6139b05', - '127.0.0.1': 'lo0', - 'Localhost': 'lo0', + "192.168.240.137": "wan", + "wan": "wan", + "192.168.1.242": "lan", + "10.255.2.254": "_vip5c0a4b6139b05", + "127.0.0.1": "lo0", + "Localhost": "lo0", } - if param == 'sourceip': - self.assert_xml_elt_equal(target_elt, xml_field, interface_map.get(obj[param], obj[param])) + if param == "sourceip": + self.assert_xml_elt_equal( + target_elt, + xml_field, + interface_map.get(obj[param], obj[param]), + ) else: self.assert_xml_elt_equal(target_elt, xml_field, obj[param]) else: if param in self.defaults: - self.assert_xml_elt_equal(target_elt, xml_field, self.defaults[param]) + self.assert_xml_elt_equal( + target_elt, xml_field, self.defaults[param] + ) else: self.assert_not_find_xml_elt(target_elt, xml_field) @@ -74,13 +81,15 @@ def check_bool_param(param, xml_field=None): if param in obj: # Special handling for inverted field # When nologdefaultpass is present in xml, value is False - if param == 'nologdefaultpass': + if param == "nologdefaultpass": if obj[param]: self.assert_not_find_xml_elt(target_elt, param) else: - self.assert_xml_elt_equal(target_elt, xml_field, '') + self.assert_xml_elt_equal(target_elt, xml_field, "") else: - self.check_param_bool(obj, target_elt, param, xml_field=xml_field) + self.check_param_bool( + obj, target_elt, param, xml_field=xml_field + ) else: if param in self.defaults: if self.defaults[param]: @@ -90,717 +99,748 @@ def check_bool_param(param, xml_field=None): else: self.assert_not_find_xml_elt(target_elt, xml_field) - check_param('logformat', xml_field='format') - check_bool_param('reverse') - check_param('nentries') - check_bool_param('nologdefaultblock') - check_bool_param('nologdefaultpass') - check_bool_param('nologbogons') - check_bool_param('nologprivatenets') - check_bool_param('nolognginx') - check_bool_param('rawfilter') - check_param('filterdescriptions') - check_bool_param('disablelocallogging') - check_param('logfilesize') - check_param('logcompressiontype') - check_param('rotatecount') - check_bool_param('enable') - check_param('sourceip') - check_param('ipproto') - check_param('remoteserver') - check_param('remoteserver2') - check_param('remoteserver3') - check_bool_param('logall') - check_bool_param('system') - check_bool_param('logfilter', xml_field='filter') - check_bool_param('resolver') - check_bool_param('dhcp') - check_bool_param('ppp') - check_bool_param('auth') - check_bool_param('portalauth') - check_bool_param('vpn') - check_bool_param('dpinger') - check_bool_param('routing') - check_bool_param('ntpd') - check_bool_param('hostapd') + check_param("logformat", xml_field="format") + check_bool_param("reverse") + check_param("nentries") + check_bool_param("nologdefaultblock") + check_bool_param("nologdefaultpass") + check_bool_param("nologbogons") + check_bool_param("nologprivatenets") + check_bool_param("nolognginx") + check_bool_param("rawfilter") + check_param("filterdescriptions") + check_bool_param("disablelocallogging") + check_param("logfilesize") + check_param("logcompressiontype") + check_param("rotatecount") + check_bool_param("enable") + check_param("sourceip") + check_param("ipproto") + check_param("remoteserver") + check_param("remoteserver2") + check_param("remoteserver3") + check_bool_param("logall") + check_bool_param("system") + check_bool_param("logfilter", xml_field="filter") + check_bool_param("resolver") + check_bool_param("dhcp") + check_bool_param("ppp") + check_bool_param("auth") + check_bool_param("portalauth") + check_bool_param("vpn") + check_bool_param("dpinger") + check_bool_param("routing") + check_bool_param("ntpd") + check_bool_param("hostapd") def test_syslog_logformat_rfc5424(self): - """ test syslog format rfc5424 """ - syslog = dict(logformat='rfc5424') + """test syslog format rfc5424""" + syslog = dict(logformat="rfc5424") command = "update log_settings syslog set format='rfc5424'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logformat_rfc3164(self): - """ test syslog format rfc3164 """ - syslog = dict(logformat='rfc3164') + """test syslog format rfc3164""" + syslog = dict(logformat="rfc3164") command = "update log_settings syslog set format='rfc3164'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logformat_invalid(self): - """ test syslog format invalid """ - syslog = dict(logformat='rfc1149') - msg = 'value of logformat must be one of: rfc3164, rfc5424, got: rfc1149' + """test syslog format invalid""" + syslog = dict(logformat="rfc1149") + msg = "value of logformat must be one of: rfc3164, rfc5424, got: rfc1149" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_reverse(self): - """ test log_settings reverse=False """ + """test log_settings reverse=False""" syslog = dict(reverse=False) command = "update log_settings syslog set reverse=False" self.do_module_test(syslog, command=command, state=None) def test_syslog_reverse_true(self): - """ test log_settings reverse=True """ + """test log_settings reverse=True""" syslog = dict(reverse=True) self.do_module_test(syslog, changed=False, state=None) def test_syslog_nentries_valid(self): - """ test log_settings nentries """ - syslog = dict(nentries='5') + """test log_settings nentries""" + syslog = dict(nentries="5") command = "update log_settings syslog set nentries='5'" self.do_module_test(syslog, command=command, state=None) def test_syslog_nentries_valid2(self): - """ test log_settings nentries """ - syslog = dict(nentries='500') + """test log_settings nentries""" + syslog = dict(nentries="500") command = "update log_settings syslog set nentries='500'" self.do_module_test(syslog, command=command, state=None) def test_syslog_nentries_valid3(self): - """ test log_settings nentries """ - syslog = dict(nentries='200000') + """test log_settings nentries""" + syslog = dict(nentries="200000") command = "update log_settings syslog set nentries='200000'" self.do_module_test(syslog, command=command, state=None) def test_syslog_nentries_invalid1(self): - """ test log_settings nentries """ - syslog = dict(nentries='-1') - msg = 'nentries must be an integer from 5 to 200000' + """test log_settings nentries""" + syslog = dict(nentries="-1") + msg = "nentries must be an integer from 5 to 200000" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_nentries_invalid2(self): - """ test log_settings nentries """ - syslog = dict(nentries='4') - msg = 'nentries must be an integer from 5 to 200000' + """test log_settings nentries""" + syslog = dict(nentries="4") + msg = "nentries must be an integer from 5 to 200000" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_nentries_invalid3(self): - """ test log_settings nentries """ - syslog = dict(nentries='200001') - msg = 'nentries must be an integer from 5 to 200000' + """test log_settings nentries""" + syslog = dict(nentries="200001") + msg = "nentries must be an integer from 5 to 200000" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_nologdefaultblock_false(self): - """ test log_settings nologdefaultblock=False """ + """test log_settings nologdefaultblock=False""" syslog = dict(nologdefaultblock=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_nologdefaultblock_true(self): - """ test log_settings nologdefaultblock=True """ + """test log_settings nologdefaultblock=True""" syslog = dict(nologdefaultblock=True) command = "update log_settings syslog set nologdefaultblock=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_nologdefaultpass_false(self): - """ test log_settings nologdefaultpass=False """ + """test log_settings nologdefaultpass=False""" syslog = dict(nologdefaultpass=False) # different bool values are correct, logic is inverted command = "update log_settings syslog set nologdefaultpass=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_nologdefaultpass_true(self): - """ test log_settings nologdefaultpass=True """ + """test log_settings nologdefaultpass=True""" syslog = dict(nologdefaultpass=True) self.do_module_test(syslog, changed=False, state=None) def test_syslog_nologbogons_false(self): - """ test log_settings nologbogons=False """ + """test log_settings nologbogons=False""" syslog = dict(nologbogons=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_nologbogons_true(self): - """ test log_settings nologbogons=True """ + """test log_settings nologbogons=True""" syslog = dict(nologbogons=True) command = "update log_settings syslog set nologbogons=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_nologprivatenets_false(self): - """ test log_settings nologprivatenets=False """ + """test log_settings nologprivatenets=False""" syslog = dict(nologprivatenets=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_nologprivatenets_true(self): - """ test log_settings nologprivatenets=True """ + """test log_settings nologprivatenets=True""" syslog = dict(nologprivatenets=True) command = "update log_settings syslog set nologprivatenets=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_nolognginx_false(self): - """ test log_settings nolognginx=False """ + """test log_settings nolognginx=False""" syslog = dict(nolognginx=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_nolognginx_true(self): - """ test log_settings nolognginx=True """ + """test log_settings nolognginx=True""" syslog = dict(nolognginx=True) command = "update log_settings syslog set nolognginx=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_rawfilter_false(self): - """ test log_settings rawfilter=False """ + """test log_settings rawfilter=False""" syslog = dict(rawfilter=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_rawfilter_true(self): - """ test log_settings rawfilter=True """ + """test log_settings rawfilter=True""" syslog = dict(rawfilter=True) command = "update log_settings syslog set rawfilter=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_filterdescriptions_valid0(self): - """ test log_settings filterdescriptions = 0 """ - syslog = dict(filterdescriptions='0') + """test log_settings filterdescriptions = 0""" + syslog = dict(filterdescriptions="0") command = "update log_settings syslog set filterdescriptions='0'" self.do_module_test(syslog, command=command, state=None) def test_syslog_filterdescriptions_valid1(self): - """ test log_settings filterdescriptions = 1 """ - syslog = dict(filterdescriptions='1') + """test log_settings filterdescriptions = 1""" + syslog = dict(filterdescriptions="1") self.do_module_test(syslog, changed=False, state=None) def test_syslog_filterdescriptions_valid2(self): - """ test log_settings filterdescriptions = 2 """ - syslog = dict(filterdescriptions='2') + """test log_settings filterdescriptions = 2""" + syslog = dict(filterdescriptions="2") command = "update log_settings syslog set filterdescriptions='2'" self.do_module_test(syslog, command=command, state=None) def test_syslog_filterdescriptions_invalid3(self): - """ test log_settings filterdescriptions = 3 """ - syslog = dict(filterdescriptions='3') + """test log_settings filterdescriptions = 3""" + syslog = dict(filterdescriptions="3") msg = "value of filterdescriptions must be one of: 0, 1, 2, got: 3" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_disablelocallogging_false(self): - """ test log_settings disablelocallogging=False """ + """test log_settings disablelocallogging=False""" syslog = dict(disablelocallogging=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_disablelocallogging_true(self): - """ test log_settings disablelocallogging=True """ + """test log_settings disablelocallogging=True""" syslog = dict(disablelocallogging=True) command = "update log_settings syslog set disablelocallogging=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilesize_valid1(self): - """ test log_settings logfilesize """ - syslog = dict(logfilesize='512000') + """test log_settings logfilesize""" + syslog = dict(logfilesize="512000") command = "update log_settings syslog set logfilesize='512000'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilesize_valid2(self): - """ test log_settings logfilesize """ - syslog = dict(logfilesize='100000') + """test log_settings logfilesize""" + syslog = dict(logfilesize="100000") command = "update log_settings syslog set logfilesize='100000'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilesize_valid3(self): - """ test log_settings logfilesize """ + """test log_settings logfilesize""" syslog = dict(logfilesize=int((2**32) / 2) - 1) command = "update log_settings syslog set logfilesize='2147483647'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilesize_invalid1(self): - """ test log_settings logfilesize """ - syslog = dict(logfilesize='-1') - msg = 'logfilesize must be an integer greater or equal than 100000' + """test log_settings logfilesize""" + syslog = dict(logfilesize="-1") + msg = "logfilesize must be an integer greater or equal than 100000" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_logfilesize_invalid2(self): - """ test log_settings logfilesize """ - syslog = dict(logfilesize='99999') - msg = 'logfilesize must be an integer greater or equal than 100000' + """test log_settings logfilesize""" + syslog = dict(logfilesize="99999") + msg = "logfilesize must be an integer greater or equal than 100000" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_logfilesize_invalid3(self): - """ test log_settings logfilesize """ - syslog = dict(logfilesize='0') - msg = 'logfilesize must be an integer greater or equal than 100000' + """test log_settings logfilesize""" + syslog = dict(logfilesize="0") + msg = "logfilesize must be an integer greater or equal than 100000" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_logfilesize_invalid4(self): - """ test log_settings logfilesize """ + """test log_settings logfilesize""" syslog = dict(logfilesize=int(((2**32) / 2) + 1)) - msg = 'logfilesize is too large: 2147483649' + msg = "logfilesize is too large: 2147483649" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_logcompressiontype_valid_xz(self): - """ test syslog logcompression = xz """ - syslog = dict(logcompressiontype='xz') + """test syslog logcompression = xz""" + syslog = dict(logcompressiontype="xz") command = "update log_settings syslog set logcompressiontype='xz'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logcompressiontype_valid_gzip(self): - """ test syslog logcompression = gzip """ - syslog = dict(logcompressiontype='gzip') + """test syslog logcompression = gzip""" + syslog = dict(logcompressiontype="gzip") command = "update log_settings syslog set logcompressiontype='gzip'" self.do_module_test(syslog, command=command, state=None) def test_syslog_rotatecount_valid0(self): - """ test log_settings rotatecount """ - syslog = dict(rotatecount='0') + """test log_settings rotatecount""" + syslog = dict(rotatecount="0") command = "update log_settings syslog set rotatecount='0'" self.do_module_test(syslog, command=command, state=None) def test_syslog_rotatecount_valid1(self): - """ test log_settings rotatecount """ - syslog = dict(rotatecount='7') + """test log_settings rotatecount""" + syslog = dict(rotatecount="7") command = "update log_settings syslog set rotatecount='7'" self.do_module_test(syslog, command=command, state=None) def test_syslog_rotatecount_valid2(self): - """ test log_settings rotatecount """ - syslog = dict(rotatecount='31') + """test log_settings rotatecount""" + syslog = dict(rotatecount="31") command = "update log_settings syslog set rotatecount='31'" self.do_module_test(syslog, command=command, state=None) def test_syslog_rotatecount_valid3(self): - """ test log_settings rotatecount """ - syslog = dict(rotatecount='99') + """test log_settings rotatecount""" + syslog = dict(rotatecount="99") command = "update log_settings syslog set rotatecount='99'" self.do_module_test(syslog, command=command, state=None) def test_syslog_rotatecount_invalid1(self): - """ test log_settings rotatecount """ - syslog = dict(rotatecount='-1') - msg = 'rotatecount must be an integer from 0 to 99' + """test log_settings rotatecount""" + syslog = dict(rotatecount="-1") + msg = "rotatecount must be an integer from 0 to 99" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_rotatecount_invalid2(self): - """ test log_settings rotatecount """ - syslog = dict(rotatecount='100') - msg = 'rotatecount must be an integer from 0 to 99' + """test log_settings rotatecount""" + syslog = dict(rotatecount="100") + msg = "rotatecount must be an integer from 0 to 99" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_enable_true(self): - """ test syslog format enable=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True) + """test syslog format enable=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_enable_false(self): - """ test syslog format logall=false """ + """test syslog format logall=false""" syslog = dict(enable=False) self.do_module_test(syslog, changed=False, state=None) def test_syslog_ipproto_ipv4(self): - """ test syslog ipproto ipv4 """ - syslog = dict(ipproto='ipv4') + """test syslog ipproto ipv4""" + syslog = dict(ipproto="ipv4") command = "update log_settings syslog set ipproto='ipv4'" self.do_module_test(syslog, command=command, state=None, changed=False) def test_syslog_ipproto_ipv6(self): - """ test syslog ipproto ipv6 """ - syslog = dict(ipproto='ipv6') + """test syslog ipproto ipv6""" + syslog = dict(ipproto="ipv6") command = "update log_settings syslog set ipproto='ipv6'" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_wan_ip(self): - """ test log_settings sourceip=wan """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='192.168.240.137') + """test log_settings sourceip=wan""" + syslog = dict( + enable=True, remoteserver="1.2.3.4", logall=True, sourceip="192.168.240.137" + ) command = "update log_settings syslog set enable=True, sourceip='wan', remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_wan_descr(self): - """ test log_settings sourceip=wan """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='wan') + """test log_settings sourceip=wan""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, sourceip="wan") command = "update log_settings syslog set enable=True, sourceip='wan', remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_lan(self): - """ test log_settings sourceip=lan """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='192.168.1.242') + """test log_settings sourceip=lan""" + syslog = dict( + enable=True, remoteserver="1.2.3.4", logall=True, sourceip="192.168.1.242" + ) command = "update log_settings syslog set enable=True, sourceip='lan', remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_lo0(self): - """ test log_settings sourceip=lan """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='127.0.0.1') + """test log_settings sourceip=lan""" + syslog = dict( + enable=True, remoteserver="1.2.3.4", logall=True, sourceip="127.0.0.1" + ) command = "update log_settings syslog set enable=True, sourceip='lo0', remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_descr(self): - """ test log_settings sourceip=lan """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='Localhost') + """test log_settings sourceip=lan""" + syslog = dict( + enable=True, remoteserver="1.2.3.4", logall=True, sourceip="Localhost" + ) command = "update log_settings syslog set enable=True, sourceip='lo0', remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_valid_empty(self): - """ test log_settings sourceip='' """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip=None) + """test log_settings sourceip=''""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, sourceip=None) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_valid_vip_ip(self): - """ test log_settings sourceip=_vip5c0a4b6139b05 """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='10.255.2.254') + """test log_settings sourceip=_vip5c0a4b6139b05""" + syslog = dict( + enable=True, remoteserver="1.2.3.4", logall=True, sourceip="10.255.2.254" + ) command = "update log_settings syslog set enable=True, sourceip='_vip5c0a4b6139b05', remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_sourceip_invalid_vip(self): - """ test log_settings sourceip=_vip5c0a4b6139b06 """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='_vip5c0a4b6139b05') + """test log_settings sourceip=_vip5c0a4b6139b06""" + syslog = dict( + enable=True, + remoteserver="1.2.3.4", + logall=True, + sourceip="_vip5c0a4b6139b05", + ) msg = "sourceip: Invalid address _vip5c0a4b6139b05!" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_sourceip_invalid_opt4(self): - """ test log_settings sourceip=opt4 """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, sourceip='opt4') + """test log_settings sourceip=opt4""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, sourceip="opt4") msg = "sourceip: Invalid address opt4!" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_remoteserver_hostname(self): - """ test log_settings remoteserver_hostname """ - syslog = dict(enable=True, remoteserver='2001:0db8:cafe:affe:0000:0000:0000:0001', logall=True) + """test log_settings remoteserver_hostname""" + syslog = dict( + enable=True, + remoteserver="2001:0db8:cafe:affe:0000:0000:0000:0001", + logall=True, + ) command = "update log_settings syslog set enable=True, remoteserver='2001:0db8:cafe:affe:0000:0000:0000:0001', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_remoteserver_fqdn(self): - """ test log_settings remoteserver_fqdn """ - syslog = dict(enable=True, remoteserver='logserver.example.com', logall=True) + """test log_settings remoteserver_fqdn""" + syslog = dict(enable=True, remoteserver="logserver.example.com", logall=True) command = "update log_settings syslog set enable=True, remoteserver='logserver.example.com', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_remoteserver_fqdn_port(self): - """ test log_settings remoteserver_fqdn_port """ - syslog = dict(enable=True, remoteserver='logserver.example.com:514', logall=True) + """test log_settings remoteserver_fqdn_port""" + syslog = dict( + enable=True, remoteserver="logserver.example.com:514", logall=True + ) command = "update log_settings syslog set enable=True, remoteserver='logserver.example.com:514', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_remoteserver_ipv6(self): - """ test log_settings remoteserver_ipv6 """ - syslog = dict(enable=True, remoteserver='2001:0db8:cafe:affe:0000:0000:0000:0001', logall=True) + """test log_settings remoteserver_ipv6""" + syslog = dict( + enable=True, + remoteserver="2001:0db8:cafe:affe:0000:0000:0000:0001", + logall=True, + ) command = "update log_settings syslog set enable=True, remoteserver='2001:0db8:cafe:affe:0000:0000:0000:0001', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_remoteserver_ipv6_port(self): - """ test log_settings remoteserver_ipv6 """ - syslog = dict(enable=True, remoteserver='[2001:0db8:cafe:affe:0000:0000:0000:0001]:514', logall=True) + """test log_settings remoteserver_ipv6""" + syslog = dict( + enable=True, + remoteserver="[2001:0db8:cafe:affe:0000:0000:0000:0001]:514", + logall=True, + ) command = "update log_settings syslog set enable=True, remoteserver='[2001:0db8:cafe:affe:0000:0000:0000:0001]:514', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_remoteserver_ipv4_invalid_port1(self): - """ test log_settings remoteserver_ipv4_invalid_port1 """ - syslog = dict(enable=True, remoteserver='1234:0', logall=True) + """test log_settings remoteserver_ipv4_invalid_port1""" + syslog = dict(enable=True, remoteserver="1234:0", logall=True) msg = "Invalid port 0" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_remoteserver_ipv4_invalid_port2(self): - """ test log_settings remoteserver_ipv4_invalid_port1 """ - syslog = dict(enable=True, remoteserver='1234:65536', logall=True) + """test log_settings remoteserver_ipv4_invalid_port1""" + syslog = dict(enable=True, remoteserver="1234:65536", logall=True) msg = "Invalid port 65536" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_ipproto_invalid(self): - """ test syslog ipproto invalid """ - syslog = dict(ipproto='ipv5') - msg = 'value of ipproto must be one of: ipv4, ipv6, got: ipv5' + """test syslog ipproto invalid""" + syslog = dict(ipproto="ipv5") + msg = "value of ipproto must be one of: ipv4, ipv6, got: ipv5" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_logall_true(self): - """ test syslog format logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True) + """test syslog format logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', logall=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_logall_false(self): - """ test syslog format logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False) + """test syslog format logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_system_true(self): - """ test syslog format system=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', system=True) + """test syslog format system=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", system=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', system=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_system_false(self): - """ test syslog format system=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', system=False) + """test syslog format system=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", system=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_system_invalid_with_logall(self): - """ test syslog format system=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, system=True) - msg = 'system = True is invalid when logall is True' + """test syslog format system=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, system=True) + msg = "system = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_system_valid_with_logall(self): - """ test syslog format system=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, system=True) + """test syslog format system=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, system=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', system=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilter_true(self): - """ test syslog format logfilter=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logfilter=True) + """test syslog format logfilter=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logfilter=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', filter=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilter_false(self): - """ test syslog format logfilter=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logfilter=False) + """test syslog format logfilter=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logfilter=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_logfilter_invalid_with_logall(self): - """ test syslog format logfilter=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, logfilter=True) - msg = 'logfilter = True is invalid when logall is True' + """test syslog format logfilter=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, logfilter=True) + msg = "logfilter = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_logfilter_valid_with_logall(self): - """ test syslog format logfilter=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, logfilter=True) + """test syslog format logfilter=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, logfilter=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', filter=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_resolver_true(self): - """ test syslog format resolver=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', resolver=True) + """test syslog format resolver=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", resolver=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', resolver=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_resolver_false(self): - """ test syslog format resolver=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', resolver=False) + """test syslog format resolver=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", resolver=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_resolver_invalid_with_logall(self): - """ test syslog format resolver=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, resolver=True) - msg = 'resolver = True is invalid when logall is True' + """test syslog format resolver=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, resolver=True) + msg = "resolver = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_resolver_valid_with_logall(self): - """ test syslog format resolver=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, resolver=True) + """test syslog format resolver=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, resolver=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', resolver=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_dhcp_true(self): - """ test syslog format dhcp=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', dhcp=True) + """test syslog format dhcp=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", dhcp=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', dhcp=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_dhcp_false(self): - """ test syslog format dhcp=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', dhcp=False) + """test syslog format dhcp=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", dhcp=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_dhcp_invalid_with_logall(self): - """ test syslog format dhcp=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, dhcp=True) - msg = 'dhcp = True is invalid when logall is True' + """test syslog format dhcp=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, dhcp=True) + msg = "dhcp = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_dhcp_valid_with_logall(self): - """ test syslog format dhcp=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, dhcp=True) + """test syslog format dhcp=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, dhcp=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', dhcp=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_ppp_true(self): - """ test syslog format ppp=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', ppp=True) + """test syslog format ppp=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", ppp=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', ppp=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_ppp_false(self): - """ test syslog format ppp=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', ppp=False) + """test syslog format ppp=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", ppp=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_ppp_invalid_with_logall(self): - """ test syslog format ppp=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, ppp=True) - msg = 'ppp = True is invalid when logall is True' + """test syslog format ppp=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, ppp=True) + msg = "ppp = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_ppp_valid_with_logall(self): - """ test syslog format ppp=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, ppp=True) + """test syslog format ppp=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, ppp=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', ppp=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_auth_true(self): - """ test syslog format auth=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', auth=True) + """test syslog format auth=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", auth=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', auth=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_auth_false(self): - """ test syslog format auth=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', auth=False) + """test syslog format auth=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", auth=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_auth_invalid_with_logall(self): - """ test syslog format auth=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, auth=True) - msg = 'auth = True is invalid when logall is True' + """test syslog format auth=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, auth=True) + msg = "auth = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_auth_valid_with_logall(self): - """ test syslog format auth=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, auth=True) + """test syslog format auth=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, auth=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', auth=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_portalauth_true(self): - """ test syslog format portalauth=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', portalauth=True) + """test syslog format portalauth=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", portalauth=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', portalauth=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_portalauth_false(self): - """ test syslog format portalauth=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', portalauth=False) + """test syslog format portalauth=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", portalauth=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_portalauth_invalid_with_logall(self): - """ test syslog format portalauth=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, portalauth=True) - msg = 'portalauth = True is invalid when logall is True' + """test syslog format portalauth=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, portalauth=True) + msg = "portalauth = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_portalauth_valid_with_logall(self): - """ test syslog format portalauth=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, portalauth=True) + """test syslog format portalauth=true, logall=false""" + syslog = dict( + enable=True, remoteserver="1.2.3.4", logall=False, portalauth=True + ) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', portalauth=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_vpn_true(self): - """ test syslog format vpn=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', vpn=True) + """test syslog format vpn=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", vpn=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', vpn=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_vpn_false(self): - """ test syslog format vpn=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', vpn=False) + """test syslog format vpn=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", vpn=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_vpn_invalid_with_logall(self): - """ test syslog format vpn=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, vpn=True) - msg = 'vpn = True is invalid when logall is True' + """test syslog format vpn=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, vpn=True) + msg = "vpn = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_vpn_valid_with_logall(self): - """ test syslog format vpn=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, vpn=True) + """test syslog format vpn=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, vpn=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', vpn=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_dpinger_true(self): - """ test syslog format dpinger=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', dpinger=True) + """test syslog format dpinger=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", dpinger=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', dpinger=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_dpinger_false(self): - """ test syslog format dpinger=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', dpinger=False) + """test syslog format dpinger=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", dpinger=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_dpinger_invalid_with_logall(self): - """ test syslog format dpinger=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, dpinger=True) - msg = 'dpinger = True is invalid when logall is True' + """test syslog format dpinger=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, dpinger=True) + msg = "dpinger = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_dpinger_valid_with_logall(self): - """ test syslog format dpinger=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, dpinger=True) + """test syslog format dpinger=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, dpinger=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', dpinger=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_routing_true(self): - """ test syslog format routing=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', routing=True) + """test syslog format routing=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", routing=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', routing=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_routing_false(self): - """ test syslog format routing=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', routing=False) + """test syslog format routing=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", routing=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_routing_invalid_with_logall(self): - """ test syslog format routing=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, routing=True) - msg = 'routing = True is invalid when logall is True' + """test syslog format routing=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, routing=True) + msg = "routing = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_routing_valid_with_logall(self): - """ test syslog format routing=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, routing=True) + """test syslog format routing=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, routing=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', routing=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_ntpd_true(self): - """ test syslog format ntpd=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', ntpd=True) + """test syslog format ntpd=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", ntpd=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', ntpd=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_ntpd_false(self): - """ test syslog format ntpd=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', ntpd=False) + """test syslog format ntpd=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", ntpd=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_ntpd_invalid_with_logall(self): - """ test syslog format ntpd=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, ntpd=True) - msg = 'ntpd = True is invalid when logall is True' + """test syslog format ntpd=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, ntpd=True) + msg = "ntpd = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_ntpd_valid_with_logall(self): - """ test syslog format ntpd=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, ntpd=True) + """test syslog format ntpd=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, ntpd=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', ntpd=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_hostapd_true(self): - """ test syslog format hostapd=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', hostapd=True) + """test syslog format hostapd=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", hostapd=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', hostapd=True" self.do_module_test(syslog, command=command, state=None) def test_syslog_hostapd_false(self): - """ test syslog format hostapd=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', hostapd=False) + """test syslog format hostapd=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", hostapd=False) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4'" self.do_module_test(syslog, command=command, state=None) def test_syslog_hostapd_invalid_with_logall(self): - """ test syslog format hostapd=true, logall=true """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=True, hostapd=True) - msg = 'hostapd = True is invalid when logall is True' + """test syslog format hostapd=true, logall=true""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=True, hostapd=True) + msg = "hostapd = True is invalid when logall is True" self.do_module_test(syslog, msg=msg, state=None, failed=True) def test_syslog_hostapd_valid_with_logall(self): - """ test syslog format hostapd=true, logall=false """ - syslog = dict(enable=True, remoteserver='1.2.3.4', logall=False, hostapd=True) + """test syslog format hostapd=true, logall=false""" + syslog = dict(enable=True, remoteserver="1.2.3.4", logall=False, hostapd=True) command = "update log_settings syslog set enable=True, remoteserver='1.2.3.4', hostapd=True" self.do_module_test(syslog, command=command, state=None) diff --git a/tests/unit/plugins/modules/test_pfsense_nat_outbound.py b/tests/unit/plugins/modules/test_pfsense_nat_outbound.py index 42056e3f..801f9763 100644 --- a/tests/unit/plugins/modules/test_pfsense_nat_outbound.py +++ b/tests/unit/plugins/modules/test_pfsense_nat_outbound.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,73 +12,80 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_nat_outbound -from ansible_collections.pfsensible.core.plugins.module_utils.nat_outbound import PFSenseNatOutboundModule +from ansible_collections.pfsensible.core.plugins.module_utils.nat_outbound import ( + PFSenseNatOutboundModule, +) from .pfsense_module import TestPFSenseModule from ipaddress import ip_address, IPv4Address class TestPFSenseNatOutboundModule(TestPFSenseModule): - module = pfsense_nat_outbound def __init__(self, *args, **kwargs): super(TestPFSenseNatOutboundModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_nat_outbound.xml' + self.config_file = "pfsense_nat_outbound.xml" self.pfmodule = PFSenseNatOutboundModule @staticmethod def is_ipv4_address(address): - """ test if address is a valid ipv4 address """ + """test if address is a valid ipv4 address""" try: - addr = ip_address(u'{0}'.format(address)) + addr = ip_address("{0}".format(address)) return isinstance(addr, IPv4Address) except ValueError: pass return False def parse_address(self, name, addr, field, invert=False): - """ return address parsed in dict """ - parts = addr.split(':') + """return address parsed in dict""" + parts = addr.split(":") res = {} port = None - if parts[0] == 'NET': + if parts[0] == "NET": res[field] = parts[1] if len(parts) > 2: - port = parts[2].replace('-', ':') + port = parts[2].replace("-", ":") else: - if parts[0] == 'any': - if name == 'source': - res[field] = 'any' + if parts[0] == "any": + if name == "source": + res[field] = "any" else: - res['any'] = None - elif parts[0] == '(self)': - res[field] = '(self)' - elif parts[0] in ['lan', 'vpn', 'vt1', 'lan_100']: + res["any"] = None + elif parts[0] == "(self)": + res[field] = "(self)" + elif parts[0] in ["lan", "vpn", "vt1", "lan_100"]: res[field] = self.unalias_interface(parts[0]) else: res[field] = parts[0] - if field in res and self.is_ipv4_address(res[field]) and res[field].find('/') == -1: - res[field] += '/32' + if ( + field in res + and self.is_ipv4_address(res[field]) + and res[field].find("/") == -1 + ): + res[field] += "/32" if len(parts) > 1: - port = parts[1].replace('-', ':') + port = parts[1].replace("-", ":") if invert: - res['not'] = None + res["not"] = None return (res, port) @staticmethod def reparse_network(value): - if value == '1.2.3.4/24': - return '1.2.3.0/24' - elif value == '2.3.4.5/24': - return '2.3.4.0/24' + if value == "1.2.3.4/24": + return "1.2.3.0/24" + elif value == "2.3.4.5/24": + return "2.3.4.0/24" return value def check_addr(self, params, target_elt, addr, field, port, invert=False): - """ test the addresses definition """ - (addr_dict, port_value) = self.parse_address(addr, params[addr], field, invert=invert) + """test the addresses definition""" + (addr_dict, port_value) = self.parse_address( + addr, params[addr], field, invert=invert + ) addr_elt = self.assert_find_xml_elt(target_elt, addr) for key, value in addr_dict.items(): self.check_value_equal(addr_elt, key, self.reparse_network(value)) @@ -85,75 +93,88 @@ def check_addr(self, params, target_elt, addr, field, port, invert=False): for item_elt in addr_elt: self.assertTrue(item_elt.tag in addr_dict) - self.check_value_equal(target_elt, port, port_value, port == 'sourceport') + self.check_value_equal(target_elt, port, port_value, port == "sourceport") def check_target_addr(self, params, target_elt): - """ test the addresses definition """ - if 'address' not in params or params['address'] == '': - self.assert_xml_elt_is_none_or_empty(target_elt, 'target') - self.assert_xml_elt_is_none_or_empty(target_elt, 'target_subnet') - self.assert_not_find_xml_elt(target_elt, 'natport') - elif params['address'] == '4.5.6.7:888-999': - self.assert_xml_elt_equal(target_elt, 'target', '4.5.6.7') - self.assert_xml_elt_equal(target_elt, 'target_subnet', '32') - self.assert_xml_elt_equal(target_elt, 'natport', '888:999') - elif params['address'] == '4.5.6.7/24:888-999': - self.assert_xml_elt_equal(target_elt, 'target', '4.5.6.0') - self.assert_xml_elt_equal(target_elt, 'target_subnet', '24') - self.assert_xml_elt_equal(target_elt, 'natport', '888:999') + """test the addresses definition""" + if "address" not in params or params["address"] == "": + self.assert_xml_elt_is_none_or_empty(target_elt, "target") + self.assert_xml_elt_is_none_or_empty(target_elt, "target_subnet") + self.assert_not_find_xml_elt(target_elt, "natport") + elif params["address"] == "4.5.6.7:888-999": + self.assert_xml_elt_equal(target_elt, "target", "4.5.6.7") + self.assert_xml_elt_equal(target_elt, "target_subnet", "32") + self.assert_xml_elt_equal(target_elt, "natport", "888:999") + elif params["address"] == "4.5.6.7/24:888-999": + self.assert_xml_elt_equal(target_elt, "target", "4.5.6.0") + self.assert_xml_elt_equal(target_elt, "target_subnet", "24") + self.assert_xml_elt_equal(target_elt, "natport", "888:999") @staticmethod def md5(value): - if value == 'acme_key': - return '0xfdc529cc680c4e8c74efbf114ec436fb' + if value == "acme_key": + return "0xfdc529cc680c4e8c74efbf114ec436fb" return value def check_target_elt(self, obj, target_elt, target_idx=-1): - """ test the xml definition """ - self.check_addr(obj, target_elt, 'source', 'network', 'sourceport') - self.check_addr(obj, target_elt, 'destination', 'network', 'dstport', invert=obj.get('invert')) + """test the xml definition""" + self.check_addr(obj, target_elt, "source", "network", "sourceport") + self.check_addr( + obj, + target_elt, + "destination", + "network", + "dstport", + invert=obj.get("invert"), + ) self.check_target_addr(obj, target_elt) - self.check_param_equal_or_not_find(obj, target_elt, 'disabled') - self.check_param_equal_or_not_find(obj, target_elt, 'nonat') - self.check_param_equal_or_not_find(obj, target_elt, 'staticnatport') - self.check_param_equal_or_not_find(obj, target_elt, 'nosync') - self.check_param_equal_or_not_find(obj, target_elt, 'nonat') + self.check_param_equal_or_not_find(obj, target_elt, "disabled") + self.check_param_equal_or_not_find(obj, target_elt, "nonat") + self.check_param_equal_or_not_find(obj, target_elt, "staticnatport") + self.check_param_equal_or_not_find(obj, target_elt, "nosync") + self.check_param_equal_or_not_find(obj, target_elt, "nonat") - self.check_value_equal(target_elt, 'interface', self.unalias_interface(obj['interface'])) - self.check_param_equal(obj, target_elt, 'ipprotocol', 'inet46', not_find_val='inet46') - self.check_param_equal(obj, target_elt, 'protocol', 'any', not_find_val='any') - self.check_param_equal(obj, target_elt, 'poolopts') - self.check_value_equal(target_elt, 'source_hash_key', self.md5(obj.get('source_hash_key'))) + self.check_value_equal( + target_elt, "interface", self.unalias_interface(obj["interface"]) + ) + self.check_param_equal( + obj, target_elt, "ipprotocol", "inet46", not_find_val="inet46" + ) + self.check_param_equal(obj, target_elt, "protocol", "any", not_find_val="any") + self.check_param_equal(obj, target_elt, "poolopts") + self.check_value_equal( + target_elt, "source_hash_key", self.md5(obj.get("source_hash_key")) + ) self.check_rule_idx(obj, target_idx) def check_rule_idx(self, params, target_idx): - """ test the xml position """ - nat_elt = self.assert_find_xml_elt(self.xml_result, 'nat') - rules_elt = self.assert_find_xml_elt(nat_elt, 'outbound') + """test the xml position""" + nat_elt = self.assert_find_xml_elt(self.xml_result, "nat") + rules_elt = self.assert_find_xml_elt(nat_elt, "outbound") idx = -1 for rule_elt in rules_elt: - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue idx += 1 - descr_elt = rule_elt.find('descr') + descr_elt = rule_elt.find("descr") self.assertIsNotNone(descr_elt) self.assertIsNotNone(descr_elt.text) - if descr_elt.text == params['descr']: + if descr_elt.text == params["descr"]: self.assertEqual(idx, target_idx) return - self.fail('rule not found ' + str(idx)) + self.fail("rule not found " + str(idx)) def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - nat_elt = self.assert_find_xml_elt(self.xml_result, 'nat') - outbount_elt = self.assert_find_xml_elt(nat_elt, 'outbound') + """get the generated xml definition""" + nat_elt = self.assert_find_xml_elt(self.xml_result, "nat") + outbount_elt = self.assert_find_xml_elt(nat_elt, "outbound") for item in outbount_elt: - descr_elt = item.find('descr') - if descr_elt is not None and descr_elt.text == obj['descr']: + descr_elt = item.find("descr") + if descr_elt is not None and descr_elt.text == obj["descr"]: return item return None @@ -162,14 +183,25 @@ def get_target_elt(self, obj, absent=False, module_result=None): # tests # def test_nat_outbound_create(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="any", + destination="1.2.3.4:443", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_aliases(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='srv_admin:port_ssh', destination='srv_admin:port_ssh', address='srv_admin:port_ssh') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="srv_admin:port_ssh", + destination="srv_admin:port_ssh", + address="srv_admin:port_ssh", + ) command = ( "create nat_outbound 'https-source-rewriting', interface='lan', source='srv_admin:port_ssh', " "destination='srv_admin:port_ssh', address='srv_admin:port_ssh'" @@ -177,132 +209,237 @@ def test_nat_outbound_create_aliases(self): self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_address(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', address='4.5.6.7:888-999') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="any", + destination="1.2.3.4:443", + address="4.5.6.7:888-999", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', address='4.5.6.7/32:888-999'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_address_net(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', address='4.5.6.7/24:888-999') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="any", + destination="1.2.3.4:443", + address="4.5.6.7/24:888-999", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', address='4.5.6.0/24:888-999'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_networks(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='1.2.3.4/24', destination='2.3.4.5/24:443') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="1.2.3.4/24", + destination="2.3.4.5/24:443", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='1.2.3.4/24', destination='2.3.4.5/24:443'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_ipprotocol(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', ipprotocol='inet', source='any', destination='1.2.3.4:443') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + ipprotocol="inet", + source="any", + destination="1.2.3.4:443", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', ipprotocol='inet', source='any', destination='1.2.3.4:443'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_protocol(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', protocol='tcp', source='any', destination='1.2.3.4:443') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + protocol="tcp", + source="any", + destination="1.2.3.4:443", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', protocol='tcp', source='any', destination='1.2.3.4:443'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_networks_invert(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='1.2.3.4/24', destination='2.3.4.5/24:443', invert=True) + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="1.2.3.4/24", + destination="2.3.4.5/24:443", + invert=True, + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='1.2.3.4/24', destination='2.3.4.5/24:443', invert=True" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_interface_destination_network(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='1.2.3.4/24', destination='NET:lan:443') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="1.2.3.4/24", + destination="NET:lan:443", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='1.2.3.4/24', destination='NET:lan:443'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_interface_source_network(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='NET:lan', destination='2.3.4.5/24:443') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="NET:lan", + destination="2.3.4.5/24:443", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='NET:lan', destination='2.3.4.5/24:443'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_top(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', after='top') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="any", + destination="1.2.3.4:443", + after="top", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', after='top'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_outbound_create_after(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', after='one rule') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="any", + destination="1.2.3.4:443", + after="one rule", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', after='one rule'" self.do_module_test(obj, command=command, target_idx=1) def test_nat_outbound_create_before(self): - """ test """ - obj = dict(descr='https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', before='another rule') + """test""" + obj = dict( + descr="https-source-rewriting", + interface="lan", + source="any", + destination="1.2.3.4:443", + before="another rule", + ) command = "create nat_outbound 'https-source-rewriting', interface='lan', source='any', destination='1.2.3.4:443', before='another rule'" self.do_module_test(obj, command=command, target_idx=1) def test_nat_outbound_create_with_sourcehashkey(self): - """ test """ - obj = dict(descr='valid', interface='lan', source='any', destination='1.2.3.4:443', source_hash_key='0x12345678901234567890123456789012') + """test""" + obj = dict( + descr="valid", + interface="lan", + source="any", + destination="1.2.3.4:443", + source_hash_key="0x12345678901234567890123456789012", + ) command = "create nat_outbound 'valid', interface='lan', source='any', destination='1.2.3.4:443', source_hash_key='0x12345678901234567890123456789012'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_create_with_sourcehashkey_str(self): - """ test """ - obj = dict(descr='valid', interface='lan', source='any', destination='1.2.3.4:443', source_hash_key='acme_key') + """test""" + obj = dict( + descr="valid", + interface="lan", + source="any", + destination="1.2.3.4:443", + source_hash_key="acme_key", + ) command = "create nat_outbound 'valid', interface='lan', source='any', destination='1.2.3.4:443', source_hash_key='0xfdc529cc680c4e8c74efbf114ec436fb'" self.do_module_test(obj, command=command, target_idx=3) def test_nat_outbound_update_noop(self): - """ test """ - obj = dict(descr='one rule', interface='wan', source='any', destination='any') + """test""" + obj = dict(descr="one rule", interface="wan", source="any", destination="any") self.do_module_test(obj, target_idx=0, changed=False) def test_nat_outbound_update_bottom(self): - """ test """ - obj = dict(descr='one rule', interface='wan', source='any', destination='any', before='bottom') + """test""" + obj = dict( + descr="one rule", + interface="wan", + source="any", + destination="any", + before="bottom", + ) command = "update nat_outbound 'one rule' set before='bottom'" self.do_module_test(obj, command=command, target_idx=2) def test_nat_outbound_update_top(self): - """ test """ - obj = dict(descr='another rule', interface='wan', source='any', destination='any', after='top') + """test""" + obj = dict( + descr="another rule", + interface="wan", + source="any", + destination="any", + after="top", + ) command = "update nat_outbound 'another rule' set after='top'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_outbound_update_source(self): - """ test """ - obj = dict(descr='one rule', interface='wan', source='(self):123', destination='any') + """test""" + obj = dict( + descr="one rule", interface="wan", source="(self):123", destination="any" + ) command = "update nat_outbound 'one rule' set source='(self):123'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_outbound_update_destination(self): - """ test """ - obj = dict(descr='one rule', interface='wan', source='any', destination='1.2.3.4:555') + """test""" + obj = dict( + descr="one rule", interface="wan", source="any", destination="1.2.3.4:555" + ) command = "update nat_outbound 'one rule' set destination='1.2.3.4/32:555'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_outbound_update_interface(self): - """ test """ - obj = dict(descr='one rule', interface='lan_100', source='any', destination='any') + """test""" + obj = dict( + descr="one rule", interface="lan_100", source="any", destination="any" + ) command = "update nat_outbound 'one rule' set interface='lan_100'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_outbound_delete(self): - """ test """ - obj = dict(descr='one rule') + """test""" + obj = dict(descr="one rule") command = "delete nat_outbound 'one rule'" self.do_module_test(obj, command=command, delete=True) def test_nat_outbound_invalid_sourcehashkey_hex(self): - """ test """ - obj = dict(descr='invalid', interface='lan', source='any', destination='1.2.3.4:443', source_hash_key='0xg2345678901234567890123456789012') + """test""" + obj = dict( + descr="invalid", + interface="lan", + source="any", + destination="1.2.3.4:443", + source_hash_key="0xg2345678901234567890123456789012", + ) msg = 'Incorrect format for source-hash key, "0x" must be followed by exactly 32 hexadecimal characters.' self.do_module_test(obj, msg=msg, failed=True) def test_nat_outbound_invalid_sourcehashkey_len(self): - """ test """ - obj = dict(descr='invalid', interface='lan', source='any', destination='1.2.3.4:443', source_hash_key='0x1234567890123456789012345678901') + """test""" + obj = dict( + descr="invalid", + interface="lan", + source="any", + destination="1.2.3.4:443", + source_hash_key="0x1234567890123456789012345678901", + ) msg = 'Incorrect format for source-hash key, "0x" must be followed by exactly 32 hexadecimal characters.' self.do_module_test(obj, msg=msg, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_nat_port_forward.py b/tests/unit/plugins/modules/test_pfsense_nat_port_forward.py index 07ca84b8..76b2fb3e 100644 --- a/tests/unit/plugins/modules/test_pfsense_nat_port_forward.py +++ b/tests/unit/plugins/modules/test_pfsense_nat_port_forward.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,90 +12,95 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_nat_port_forward -from ansible_collections.pfsensible.core.plugins.module_utils.nat_port_forward import PFSenseNatPortForwardModule +from ansible_collections.pfsensible.core.plugins.module_utils.nat_port_forward import ( + PFSenseNatPortForwardModule, +) from .pfsense_module import TestPFSenseModule from .test_pfsense_rule import TestPFSenseRuleModule class TestPFSenseNatPortForwardModule(TestPFSenseModule): - module = pfsense_nat_port_forward def __init__(self, *args, **kwargs): super(TestPFSenseNatPortForwardModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_nat_port_forward_config.xml' + self.config_file = "pfsense_nat_port_forward_config.xml" self.pfmodule = PFSenseNatPortForwardModule def check_target_addr(self, params, target_elt): - """ test the addresses definition """ - if params['target'] == '2.3.4.5:443': - self.assert_xml_elt_equal(target_elt, 'target', '2.3.4.5') - self.assert_xml_elt_equal(target_elt, 'local-port', '443') + """test the addresses definition""" + if params["target"] == "2.3.4.5:443": + self.assert_xml_elt_equal(target_elt, "target", "2.3.4.5") + self.assert_xml_elt_equal(target_elt, "local-port", "443") def get_associated_rule_elt(self, params, ruleid): - """ check the associated rule """ + """check the associated rule""" filters = dict() - filters['interface'] = self.unalias_interface(params['interface']) - filters['associated-rule-id'] = ruleid - return self.assert_has_xml_tag('filter', filters) + filters["interface"] = self.unalias_interface(params["interface"]) + filters["associated-rule-id"] = ruleid + return self.assert_has_xml_tag("filter", filters) def check_target_elt(self, obj, target_elt, target_idx=-1): - """ test the xml definition """ + """test the xml definition""" rules_tester = TestPFSenseRuleModule() - rules_tester.check_rule_elt_addr(obj, target_elt, 'source') + rules_tester.check_rule_elt_addr(obj, target_elt, "source") # checking destination address and ports - rules_tester.check_rule_elt_addr(obj, target_elt, 'destination') + rules_tester.check_rule_elt_addr(obj, target_elt, "destination") self.check_target_addr(obj, target_elt) - self.check_param_equal_or_not_find(obj, target_elt, 'disabled') - self.check_param_equal_or_not_find(obj, target_elt, 'nordr') - self.check_param_equal_or_not_find(obj, target_elt, 'nosync') - self.check_param_equal_or_not_find(obj, target_elt, 'natreflection', not_find_val='system-default') + self.check_param_equal_or_not_find(obj, target_elt, "disabled") + self.check_param_equal_or_not_find(obj, target_elt, "nordr") + self.check_param_equal_or_not_find(obj, target_elt, "nosync") + self.check_param_equal_or_not_find( + obj, target_elt, "natreflection", not_find_val="system-default" + ) - self.check_value_equal(target_elt, 'interface', self.unalias_interface(obj['interface'])) - self.check_param_equal(obj, target_elt, 'ipprotocol', 'inet') - self.check_param_equal(obj, target_elt, 'protocol', 'tcp') - self.check_param_equal_or_present(obj, target_elt, 'local-port') + self.check_value_equal( + target_elt, "interface", self.unalias_interface(obj["interface"]) + ) + self.check_param_equal(obj, target_elt, "ipprotocol", "inet") + self.check_param_equal(obj, target_elt, "protocol", "tcp") + self.check_param_equal_or_present(obj, target_elt, "local-port") self.check_rule_idx(obj, target_idx) - if 'associated_rule' not in obj: - obj['associated_rule'] = 'associated' + if "associated_rule" not in obj: + obj["associated_rule"] = "associated" - if obj['associated_rule'] == 'none' or obj['associated_rule'] == 'unassociated': - self.assert_xml_elt_is_none_or_empty(target_elt, 'associated-rule-id') - elif obj['associated_rule'] == 'pass': - self.check_value_equal(target_elt, 'associated-rule-id', 'pass') + if obj["associated_rule"] == "none" or obj["associated_rule"] == "unassociated": + self.assert_xml_elt_is_none_or_empty(target_elt, "associated-rule-id") + elif obj["associated_rule"] == "pass": + self.check_value_equal(target_elt, "associated-rule-id", "pass") else: - ruleid_elt = self.assert_find_xml_elt(target_elt, 'associated-rule-id') - self.assertTrue(ruleid_elt.text.startswith('nat_')) + ruleid_elt = self.assert_find_xml_elt(target_elt, "associated-rule-id") + self.assertTrue(ruleid_elt.text.startswith("nat_")) rule_elt = self.get_associated_rule_elt(obj, ruleid_elt.text) - self.assertEqual(rule_elt.find('descr').text, 'NAT ' + obj['descr']) + self.assertEqual(rule_elt.find("descr").text, "NAT " + obj["descr"]) def check_rule_idx(self, params, target_idx): - """ test the xml position """ - rules_elt = self.assert_find_xml_elt(self.xml_result, 'nat') + """test the xml position""" + rules_elt = self.assert_find_xml_elt(self.xml_result, "nat") idx = -1 for rule_elt in rules_elt: - if rule_elt.tag != 'rule': + if rule_elt.tag != "rule": continue idx += 1 - descr_elt = rule_elt.find('descr') + descr_elt = rule_elt.find("descr") self.assertIsNotNone(descr_elt) self.assertIsNotNone(descr_elt.text) - if descr_elt.text == params['descr']: + if descr_elt.text == params["descr"]: self.assertEqual(idx, target_idx) return - self.fail('rule not found ' + str(idx)) + self.fail("rule not found " + str(idx)) def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - rules_elt = self.assert_find_xml_elt(self.xml_result, 'nat') + """get the generated xml definition""" + rules_elt = self.assert_find_xml_elt(self.xml_result, "nat") for item in rules_elt: - descr_elt = item.find('descr') - if descr_elt is not None and descr_elt.text == obj['descr']: + descr_elt = item.find("descr") + if descr_elt is not None and descr_elt.text == obj["descr"]: return item return None @@ -103,16 +109,31 @@ def get_target_elt(self, obj, absent=False, module_result=None): # tests # def test_nat_port_forward_create(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='pass') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:443", + destination="1.2.3.4:443", + target="2.3.4.5:443", + associated_rule="pass", + ) command = ( - "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='pass'" + "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', " + "associated_rule='pass'" ) self.do_module_test(obj, command=command, target_idx=3) def test_nat_port_forward_create_range(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:9000-10000', destination='1.2.3.4:9000-10000', target='2.3.4.5:9000', associated_rule='none') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:9000-10000", + destination="1.2.3.4:9000-10000", + target="2.3.4.5:9000", + associated_rule="none", + ) command = ( "create nat_port_forward 'test_pf', interface='lan', source='any:9000-10000', destination='1.2.3.4:9000-10000', " "target='2.3.4.5:9000', associated_rule='none'" @@ -120,15 +141,29 @@ def test_nat_port_forward_create_range(self): self.do_module_test(obj, command=command, target_idx=3) def test_nat_port_forward_create_associated(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='associated') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:443", + destination="1.2.3.4:443", + target="2.3.4.5:443", + associated_rule="associated", + ) cmd1 = "create rule 'NAT test_pf' on 'lan', source='any:443', destination='2.3.4.5:443', protocol='tcp'" cmd2 = "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443'" self.do_module_test(obj, command=[cmd1, cmd2], target_idx=3) def test_nat_port_forward_create_unassociated(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='unassociated') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:443", + destination="1.2.3.4:443", + target="2.3.4.5:443", + associated_rule="unassociated", + ) cmd1 = "create rule 'NAT test_pf' on 'lan', source='any:443', destination='2.3.4.5:443', protocol='tcp'" cmd2 = ( "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', " @@ -137,8 +172,16 @@ def test_nat_port_forward_create_unassociated(self): self.do_module_test(obj, command=[cmd1, cmd2], target_idx=3) def test_nat_port_forward_create_top(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='pass', after='top') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:443", + destination="1.2.3.4:443", + target="2.3.4.5:443", + associated_rule="pass", + after="top", + ) command = ( "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', " "associated_rule='pass', after='top'" @@ -146,8 +189,16 @@ def test_nat_port_forward_create_top(self): self.do_module_test(obj, command=command, target_idx=0) def test_nat_port_forward_create_after(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='pass', after='one') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:443", + destination="1.2.3.4:443", + target="2.3.4.5:443", + associated_rule="pass", + after="one", + ) command = ( "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', " "associated_rule='pass', after='one'" @@ -155,8 +206,16 @@ def test_nat_port_forward_create_after(self): self.do_module_test(obj, command=command, target_idx=1) def test_nat_port_forward_create_before(self): - """ test """ - obj = dict(descr='test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', associated_rule='pass', before='two') + """test""" + obj = dict( + descr="test_pf", + interface="lan", + source="any:443", + destination="1.2.3.4:443", + target="2.3.4.5:443", + associated_rule="pass", + before="two", + ) command = ( "create nat_port_forward 'test_pf', interface='lan', source='any:443', destination='1.2.3.4:443', target='2.3.4.5:443', " "associated_rule='pass', before='two'" @@ -164,71 +223,145 @@ def test_nat_port_forward_create_before(self): self.do_module_test(obj, command=command, target_idx=1) def test_nat_port_forward_create_icmp(self): - """ test """ - obj = dict(descr='test_pf', interface='wan', protocol='icmp', source='any', destination='1.2.3.4', target='2.3.4.5', associated_rule='associated') + """test""" + obj = dict( + descr="test_pf", + interface="wan", + protocol="icmp", + source="any", + destination="1.2.3.4", + target="2.3.4.5", + associated_rule="associated", + ) command = [ "create rule 'NAT test_pf' on 'wan', source='any', destination='2.3.4.5', protocol='icmp'", - "create nat_port_forward 'test_pf', interface='wan', protocol='icmp', source='any', destination='1.2.3.4', target='2.3.4.5'" + "create nat_port_forward 'test_pf', interface='wan', protocol='icmp', source='any', destination='1.2.3.4', target='2.3.4.5'", ] self.do_module_test(obj, command=command, target_idx=3) def test_nat_port_forward_create_tcp_fail_no_port(self): - """ test """ - obj = dict(descr='test_pf', interface='wan', source='any', destination='1.2.3.4', target='2.3.4.5', associated_rule='associated') + """test""" + obj = dict( + descr="test_pf", + interface="wan", + source="any", + destination="1.2.3.4", + target="2.3.4.5", + associated_rule="associated", + ) msg = 'Must specify a target port with protocol "tcp".' self.do_module_test(obj, failed=True, msg=msg) def test_nat_port_forward_create_icmp_fail_port(self): - """ test """ - obj = dict(descr='test_pf', interface='wan', protocol='icmp', source='any', destination='1.2.3.4', target='2.3.4.5:443', associated_rule='associated') + """test""" + obj = dict( + descr="test_pf", + interface="wan", + protocol="icmp", + source="any", + destination="1.2.3.4", + target="2.3.4.5:443", + associated_rule="associated", + ) msg = 'Cannot specify a target port with protocol "icmp".' self.do_module_test(obj, failed=True, msg=msg) def test_nat_port_forward_update_noop(self): - """ test """ - obj = dict(descr='one', interface='wan', source='any', destination='IP:wan:22022', target='10.255.1.20:22', associated_rule='none') + """test""" + obj = dict( + descr="one", + interface="wan", + source="any", + destination="IP:wan:22022", + target="10.255.1.20:22", + associated_rule="none", + ) self.do_module_test(obj, target_idx=0, changed=False) def test_nat_port_forward_update_bottom(self): - """ test """ - obj = dict(descr='one', interface='wan', source='any', destination='IP:wan:22022', target='10.255.1.20:22', associated_rule='none', before='bottom') + """test""" + obj = dict( + descr="one", + interface="wan", + source="any", + destination="IP:wan:22022", + target="10.255.1.20:22", + associated_rule="none", + before="bottom", + ) command = "update nat_port_forward 'one' set before='bottom'" self.do_module_test(obj, command=command, target_idx=2) def test_nat_port_forward_update_top(self): - """ test """ - obj = dict(descr='last', interface='wan', source='any', destination='IP:wan:22022', target='10.255.1.20:22', associated_rule='associated', after='top') + """test""" + obj = dict( + descr="last", + interface="wan", + source="any", + destination="IP:wan:22022", + target="10.255.1.20:22", + associated_rule="associated", + after="top", + ) command = "update nat_port_forward 'last' set after='top'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_port_forward_update_source(self): - """ test """ - obj = dict(descr='one', interface='wan', source='1.2.3.4', destination='IP:wan:22022', target='10.255.1.20:22', associated_rule='none') + """test""" + obj = dict( + descr="one", + interface="wan", + source="1.2.3.4", + destination="IP:wan:22022", + target="10.255.1.20:22", + associated_rule="none", + ) command = "update nat_port_forward 'one' set source='1.2.3.4'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_port_forward_update_destination(self): - """ test """ - obj = dict(descr='one', interface='wan', source='any', destination='1.2.3.4:22022', target='10.255.1.20:22', associated_rule='none') + """test""" + obj = dict( + descr="one", + interface="wan", + source="any", + destination="1.2.3.4:22022", + target="10.255.1.20:22", + associated_rule="none", + ) command = "update nat_port_forward 'one' set destination='1.2.3.4:22022'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_port_forward_update_interface(self): - """ test """ - obj = dict(descr='one', interface='vpn', source='any', destination='IP:wan:22022', target='10.255.1.20:22', associated_rule='none') + """test""" + obj = dict( + descr="one", + interface="vpn", + source="any", + destination="IP:wan:22022", + target="10.255.1.20:22", + associated_rule="none", + ) command = "update nat_port_forward 'one' set interface='vpn'" self.do_module_test(obj, command=command, target_idx=0) def test_nat_port_forward_update_interface_associated(self): - """ test """ - obj = dict(descr='last', interface='lan_100', source='any', destination='IP:wan:22022', target='10.255.1.20:22', associated_rule='associated') + """test""" + obj = dict( + descr="last", + interface="lan_100", + source="any", + destination="IP:wan:22022", + target="10.255.1.20:22", + associated_rule="associated", + ) cmd1 = "delete rule 'NAT last' on 'wan'" cmd2 = "create rule 'NAT last' on 'lan_100', source='any', destination='10.255.1.20:22', protocol='tcp'" cmd3 = "update nat_port_forward 'last' set interface='lan_100'" self.do_module_test(obj, command=[cmd1, cmd2, cmd3], target_idx=2) def test_nat_port_forward_delete(self): - """ test """ - obj = dict(descr='one') + """test""" + obj = dict(descr="one") command = "delete nat_port_forward 'one'" self.do_module_test(obj, command=command, delete=True) diff --git a/tests/unit/plugins/modules/test_pfsense_openvpn_override.py b/tests/unit/plugins/modules/test_pfsense_openvpn_override.py index 09665992..95634d5a 100644 --- a/tests/unit/plugins/modules/test_pfsense_openvpn_override.py +++ b/tests/unit/plugins/modules/test_pfsense_openvpn_override.py @@ -1,7 +1,8 @@ # Copyright: (c) 2022, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -15,77 +16,101 @@ class TestPFSenseOpenVPNOverrideModule(TestPFSenseModule): - module = pfsense_openvpn_override def __init__(self, *args, **kwargs): super(TestPFSenseOpenVPNOverrideModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_openvpn_config.xml' + self.config_file = "pfsense_openvpn_config.xml" self.pfmodule = pfsense_openvpn_override.PFSenseOpenVPNOverrideModule @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ - root_elt = self.xml_result.getroot().find('openvpn') - result = root_elt.findall("openvpn-csc[common_name='{0}']".format(obj['name'])) + """return target elt from XML""" + root_elt = self.xml_result.getroot().find("openvpn") + result = root_elt.findall("openvpn-csc[common_name='{0}']".format(obj["name"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple OpenVPN overrides for name {0}.'.format(obj['name'])) + self.fail( + "Found multiple OpenVPN overrides for name {0}.".format(obj["name"]) + ) else: return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ - - self.check_param_equal(obj, target_elt, 'name', xml_field='common_name') - self.check_param_bool(obj, target_elt, 'disable') - self.check_param_bool(obj, target_elt, 'block', default=False, value_true='yes') - self.check_param_equal(obj, target_elt, 'tunnel_network') - self.check_param_equal(obj, target_elt, 'tunnel_networkv6') - self.check_param_equal(obj, target_elt, 'local_network') - self.check_param_equal(obj, target_elt, 'local_networkv6') - self.check_param_equal(obj, target_elt, 'remote_network') - self.check_param_equal(obj, target_elt, 'remote_networkv6') - self.check_param_bool(obj, target_elt, 'gwredir', default=False, value_true='yes') - self.check_param_bool(obj, target_elt, 'push_reset', default=False, value_true='yes') + """check XML definition of target elt""" + + self.check_param_equal(obj, target_elt, "name", xml_field="common_name") + self.check_param_bool(obj, target_elt, "disable") + self.check_param_bool(obj, target_elt, "block", default=False, value_true="yes") + self.check_param_equal(obj, target_elt, "tunnel_network") + self.check_param_equal(obj, target_elt, "tunnel_networkv6") + self.check_param_equal(obj, target_elt, "local_network") + self.check_param_equal(obj, target_elt, "local_networkv6") + self.check_param_equal(obj, target_elt, "remote_network") + self.check_param_equal(obj, target_elt, "remote_networkv6") + self.check_param_bool( + obj, target_elt, "gwredir", default=False, value_true="yes" + ) + self.check_param_bool( + obj, target_elt, "push_reset", default=False, value_true="yes" + ) ############## # tests # def test_openvpn_override_create(self): - """ test creation of a new OpenVPN override """ - obj = dict(name='vpnuser1', block=True) - self.do_module_test(obj, command="create openvpn_override 'vpnuser1', common_name='vpnuser1'") + """test creation of a new OpenVPN override""" + obj = dict(name="vpnuser1", block=True) + self.do_module_test( + obj, command="create openvpn_override 'vpnuser1', common_name='vpnuser1'" + ) def test_openvpn_override_delete(self): - """ test deletion of a OpenVPN override """ - obj = dict(name='delvpnuser') - self.do_module_test(obj, command="delete openvpn_override 'delvpnuser'", delete=True) + """test deletion of a OpenVPN override""" + obj = dict(name="delvpnuser") + self.do_module_test( + obj, command="delete openvpn_override 'delvpnuser'", delete=True + ) def test_openvpn_override_update_noop(self): - """ test not updating a OpenVPN override """ - obj = dict(name='delvpnuser', gwredir=True, server_list=1, custom_options='ifconfig-push 10.8.0.1 255.255.255.0') + """test not updating a OpenVPN override""" + obj = dict( + name="delvpnuser", + gwredir=True, + server_list=1, + custom_options="ifconfig-push 10.8.0.1 255.255.255.0", + ) self.do_module_test(obj, changed=False) def test_openvpn_override_update_network(self): - """ test updating network of a OpenVPN override """ - obj = dict(name='delvpnuser', gwredir=True, server_list=1, custom_options='ifconfig-push 10.8.0.1 255.255.255.0', tunnel_network='10.10.10.10/24') + """test updating network of a OpenVPN override""" + obj = dict( + name="delvpnuser", + gwredir=True, + server_list=1, + custom_options="ifconfig-push 10.8.0.1 255.255.255.0", + tunnel_network="10.10.10.10/24", + ) self.do_module_test(obj, command="update openvpn_override 'delvpnuser' set ") ############## # misc # def test_create_openvpn_override_invalid_network(self): - """ test creation of a new OpenVPN override with invalid network """ - obj = dict(name='delvpnuser', remote_network='30.4.3.3/24') - self.do_module_test(obj, failed=True, msg='A valid IPv4 network must be specified for remote_network.') + """test creation of a new OpenVPN override with invalid network""" + obj = dict(name="delvpnuser", remote_network="30.4.3.3/24") + self.do_module_test( + obj, + failed=True, + msg="A valid IPv4 network must be specified for remote_network.", + ) def test_delete_nonexistent_openvpn_override(self): - """ test deletion of an nonexistent OpenVPN override """ - obj = dict(name='novpnuser') - self.do_module_test(obj, commmand=None, state='absent', changed=False) + """test deletion of an nonexistent OpenVPN override""" + obj = dict(name="novpnuser") + self.do_module_test(obj, commmand=None, state="absent", changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_openvpn_server.py b/tests/unit/plugins/modules/test_pfsense_openvpn_server.py index c43e2dbb..d20abf6d 100644 --- a/tests/unit/plugins/modules/test_pfsense_openvpn_server.py +++ b/tests/unit/plugins/modules/test_pfsense_openvpn_server.py @@ -1,7 +1,8 @@ # Copyright: (c) 2022, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import base64 @@ -13,7 +14,9 @@ from ansible_collections.pfsensible.core.plugins.modules import pfsense_openvpn_server from .pfsense_module import TestPFSenseModule -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) CERTIFICATE = ( "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tDQpNSUlFQ0RDQ0F2Q2dBd0lCQWdJSUZqRk9oczFuTXpRd0RRWUpLb1pJaHZjTkFRRUxCUUF3WERFVE1CRUdBMVVFDQpBeE1LYjNCbGJuWndiaTFqWVRF" @@ -29,7 +32,8 @@ "Q002OUh4eFlODQpCa2lpbXd1N09mRmFGZkZDT25NSjhvcStKVGxjMG9vREoxM2xCdHRONkdybnZrUTNQMXdZYkNFTmJuaWxPYVVCDQpUSXJpSHl0TkRRYW91TmEvS1dzN0ZhdW9iY3RCbDF3OWF0b0ha" "c041b2VoVDNyQVR2MUNDQXRqcGFUSklmSlIzDQowSVFPWWtlNG9ZNkRrSXdIcDJ2UFBtb29HZ0l0YlR3M1UrRTQxWVplN3FDbUUvN3pMVFNaa0lNMmx4NnpENDZqDQpEZjRyZ044TVVMNnhpd09Mbzly" "QUp5ckRNM2JEeTJ1QjY0QkVzRFFMa2huUE92ZWtETjQ1NnV6TmpYS0E3VnE4DQpoMS9nekRaSURpK1dYQ1lBY2JnTGhaVkJxdG42MnVtRnBNUkl1dz09DQotLS0tLUVORCBDRVJUSUZJQ0FURS0tLS0t" - "DQo=") + "DQo=" +) TLSKEY = ( "IwojIDIwNDggYml0IE9wZW5WUE4gc3RhdGljIGtleQojCi0tLS0tQkVHSU4gT3BlblZQTiBTdGF0aWMga2V5IFYxLS0tLS0KNjFiY2E4MDk0ZmM4YjA3ZTZlMjE3NzRmNTI0YTIyOWYKNGMzZGZhMDVjZ" @@ -37,173 +41,231 @@ "kzMGRmMzEKMDY2Mzk1MjM2ZWRkYWQ3NDc3YmVjZjJmNDgyNzBlMjUKODM1N2JlMGE1MGUzY2Y0ZjllZTEyZTdkMmM4YTY2YzEKODUwNjBlODM5ZWUyMzdjNTZkZmUzNjA4NjU0NDhhYzgKNjhmM2JhYWQ" "4ODNjNDU3NTdlZTVjMWQ4ZDk5ZjM4ZjcKZGNiZDAwZmI3Nzc2ZWFlYjQ1ZmQwOTBjNGNlYTNmMGMKMzgzNDE0ZTJlYmU4MWNiZGIxZmNlN2M2YmFhMDlkMWYKMTU4OGUzNGRkYzUxY2NjOTE5NDNjNTFh" "OTI2OTE3NWQKNzZiZjdhOWI1ZmM3NDAyNmE3MTVkNGVmODVkYzY2Y2UKMWE5MWQwNjNhODIwZDY4MTc0ODlmYjJkZjNmYzY2MmMKMmU2OWZiMzNiMzM5MjdjYjUyNThkZDQ4M2NkNDE0Y2QKMDJhZWE3Z" - "jA3MmNhZmEwOTY5Yjg5NWVjYzNiYmExNGQKLS0tLS1FTkQgT3BlblZQTiBTdGF0aWMga2V5IFYxLS0tLS0K") + "jA3MmNhZmEwOTY5Yjg5NWVjYzNiYmExNGQKLS0tLS1FTkQgT3BlblZQTiBTdGF0aWMga2V5IFYxLS0tLS0K" +) class TestPFSenseOpenVPNServerModule(TestPFSenseModule): - module = pfsense_openvpn_server def __init__(self, *args, **kwargs): super(TestPFSenseOpenVPNServerModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_openvpn_config.xml' + self.config_file = "pfsense_openvpn_config.xml" self.pfmodule = pfsense_openvpn_server.PFSenseOpenVPNServerModule def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseOpenVPNServerModule, self).setUp() - self.mock_run_command = patch('ansible.module_utils.basic.AnsibleModule.run_command') + self.mock_run_command = patch( + "ansible.module_utils.basic.AnsibleModule.run_command" + ) self.run_command = self.mock_run_command.start() - self.run_command.return_value = (0, base64.b64decode(TLSKEY.encode()).decode(), '') - - self.mock_php = patch('ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php') + self.run_command.return_value = ( + 0, + base64.b64decode(TLSKEY.encode()).decode(), + "", + ) + + self.mock_php = patch( + "ansible_collections.pfsensible.core.plugins.module_utils.pfsense.PFSenseModule.php" + ) self.php = self.mock_php.start() - self.php.return_value = {'SHA256': 'SHA256 (256-bit)'} + self.php.return_value = {"SHA256": "SHA256 (256-bit)"} def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseOpenVPNServerModule, self).tearDown() self.run_command.stop() @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ - root_elt = self.xml_result.getroot().find('openvpn') - result = root_elt.findall("openvpn-server[description='{0}']".format(obj['name'])) + """return target elt from XML""" + root_elt = self.xml_result.getroot().find("openvpn") + result = root_elt.findall( + "openvpn-server[description='{0}']".format(obj["name"]) + ) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple OpenVPN servers for name {0}.'.format(obj['name'])) + self.fail( + "Found multiple OpenVPN servers for name {0}.".format(obj["name"]) + ) else: return None @staticmethod def caref(descr): - """ return refid for ca """ - if descr == 'OpenVPN CA': - return '6209e3cef1e81' - return '' + """return refid for ca""" + if descr == "OpenVPN CA": + return "6209e3cef1e81" + return "" @staticmethod def crlref(descr): - """ return refid for crl """ - if descr == 'OpenVPN CRL': - return '6209e3cef1e81' + """return refid for crl""" + if descr == "OpenVPN CRL": + return "6209e3cef1e81" return None @staticmethod def certref(descr): - """ return refid for cert """ - if descr == 'OpenVPN CERT': - return '6209e3cef1e81' + """return refid for cert""" + if descr == "OpenVPN CERT": + return "6209e3cef1e81" return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ + """check XML definition of target elt""" # Use "generated" key - if 'shared_key' in obj and obj['shared_key'] == 'generate': - obj['shared_key'] = TLSKEY - if 'tls' in obj and obj['tls'] == 'generate': - obj['tls'] = TLSKEY - obj['tls_type'] = 'auth' - - self.check_param_equal(obj, target_elt, 'name', xml_field='description') - self.check_param_equal(obj, target_elt, 'custom_options') - self.check_param_equal(obj, target_elt, 'mode', default='ptp_tls') - if obj['mode'] == 'server_tls_user': - self.check_list_param_equal(obj, target_elt, 'authmode') - if obj['mode'] == 'p2p_shared_key': - self.check_param_equal(obj, target_elt, 'shared_key') - self.check_param_equal(obj, target_elt, 'dev_mode', default='tun') - self.check_param_bool(obj, target_elt, 'disabled') - self.check_param_equal(obj, target_elt, 'interface', default='wan') - self.check_param_equal(obj, target_elt, 'local_port', default=1194) - self.check_param_equal(obj, target_elt, 'protocol', default='UDP4') - if 'tls' in obj['mode']: - self.check_param_equal(obj, target_elt, 'tls') - self.check_param_equal(obj, target_elt, 'tls') - self.check_param_equal(obj, target_elt, 'tls_type') - self.assert_xml_elt_equal(target_elt, 'caref', self.caref(obj['ca'])) - if 'crl' in obj: - self.assert_xml_elt_equal(target_elt, 'crlref', self.crlref(obj['crl'])) - if 'cert' in obj: - self.assert_xml_elt_equal(target_elt, 'certref', self.certref(obj['cert'])) - self.check_param_equal(obj, target_elt, 'cert_depth', default=1) + if "shared_key" in obj and obj["shared_key"] == "generate": + obj["shared_key"] = TLSKEY + if "tls" in obj and obj["tls"] == "generate": + obj["tls"] = TLSKEY + obj["tls_type"] = "auth" + + self.check_param_equal(obj, target_elt, "name", xml_field="description") + self.check_param_equal(obj, target_elt, "custom_options") + self.check_param_equal(obj, target_elt, "mode", default="ptp_tls") + if obj["mode"] == "server_tls_user": + self.check_list_param_equal(obj, target_elt, "authmode") + if obj["mode"] == "p2p_shared_key": + self.check_param_equal(obj, target_elt, "shared_key") + self.check_param_equal(obj, target_elt, "dev_mode", default="tun") + self.check_param_bool(obj, target_elt, "disabled") + self.check_param_equal(obj, target_elt, "interface", default="wan") + self.check_param_equal(obj, target_elt, "local_port", default=1194) + self.check_param_equal(obj, target_elt, "protocol", default="UDP4") + if "tls" in obj["mode"]: + self.check_param_equal(obj, target_elt, "tls") + self.check_param_equal(obj, target_elt, "tls") + self.check_param_equal(obj, target_elt, "tls_type") + self.assert_xml_elt_equal(target_elt, "caref", self.caref(obj["ca"])) + if "crl" in obj: + self.assert_xml_elt_equal(target_elt, "crlref", self.crlref(obj["crl"])) + if "cert" in obj: + self.assert_xml_elt_equal( + target_elt, "certref", self.certref(obj["cert"]) + ) + self.check_param_equal(obj, target_elt, "cert_depth", default=1) else: - self.assert_not_find_xml_elt('tls') - self.assert_not_find_xml_elt('tls_type') - self.check_param_bool(obj, target_elt, 'strictusercn') - self.check_param_equal(obj, target_elt, 'dh_length', default=2048) - self.check_param_equal(obj, target_elt, 'ecdh_curve', default='none') - self.check_param_equal(obj, target_elt, 'data_ciphers_fallback', default='AES-256-CBC') - self.check_param_equal(obj, target_elt, 'data_ciphers', default='AES-256-GCM,AES-128-GCM,CHACHA20-POLY1305') - self.check_param_bool(obj, target_elt, 'ncp_enable', default=True, value_true='enabled') - self.check_param_equal(obj, target_elt, 'digest', default='SHA256') - self.check_param_equal(obj, target_elt, 'ecdh_curve', default='none') - self.check_param_equal(obj, target_elt, 'allow_compression', default='no') - self.check_param_equal(obj, target_elt, 'compression', default=None) - self.check_param_bool(obj, target_elt, 'compression_push', default=False, value_true='yes') - self.check_param_equal(obj, target_elt, 'ecdh_curve', default='none') - self.check_param_equal(obj, target_elt, 'tunnel_network') - self.check_param_equal(obj, target_elt, 'tunnel_networkv6') - self.check_param_equal(obj, target_elt, 'local_network') - self.check_param_equal(obj, target_elt, 'local_networkv6') - self.check_param_equal(obj, target_elt, 'remote_network') - self.check_param_equal(obj, target_elt, 'remote_networkv6') - self.check_param_bool(obj, target_elt, 'gwredir', default=False, value_true='yes') - self.check_param_bool(obj, target_elt, 'gwredir6', default=False, value_true='yes') - self.check_param_equal(obj, target_elt, 'maxclients') + self.assert_not_find_xml_elt("tls") + self.assert_not_find_xml_elt("tls_type") + self.check_param_bool(obj, target_elt, "strictusercn") + self.check_param_equal(obj, target_elt, "dh_length", default=2048) + self.check_param_equal(obj, target_elt, "ecdh_curve", default="none") + self.check_param_equal( + obj, target_elt, "data_ciphers_fallback", default="AES-256-CBC" + ) + self.check_param_equal( + obj, + target_elt, + "data_ciphers", + default="AES-256-GCM,AES-128-GCM,CHACHA20-POLY1305", + ) + self.check_param_bool( + obj, target_elt, "ncp_enable", default=True, value_true="enabled" + ) + self.check_param_equal(obj, target_elt, "digest", default="SHA256") + self.check_param_equal(obj, target_elt, "ecdh_curve", default="none") + self.check_param_equal(obj, target_elt, "allow_compression", default="no") + self.check_param_equal(obj, target_elt, "compression", default=None) + self.check_param_bool( + obj, target_elt, "compression_push", default=False, value_true="yes" + ) + self.check_param_equal(obj, target_elt, "ecdh_curve", default="none") + self.check_param_equal(obj, target_elt, "tunnel_network") + self.check_param_equal(obj, target_elt, "tunnel_networkv6") + self.check_param_equal(obj, target_elt, "local_network") + self.check_param_equal(obj, target_elt, "local_networkv6") + self.check_param_equal(obj, target_elt, "remote_network") + self.check_param_equal(obj, target_elt, "remote_networkv6") + self.check_param_bool( + obj, target_elt, "gwredir", default=False, value_true="yes" + ) + self.check_param_bool( + obj, target_elt, "gwredir6", default=False, value_true="yes" + ) + self.check_param_equal(obj, target_elt, "maxclients") ############## # tests # def test_openvpn_server_create(self): - """ test creation of a new OpenVPN server """ - obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA', local_port=1196) - self.do_module_test(obj, command="create openvpn_server 'ovpns3', description='ovpns3'") + """test creation of a new OpenVPN server""" + obj = dict(name="ovpns3", mode="p2p_tls", ca="OpenVPN CA", local_port=1196) + self.do_module_test( + obj, command="create openvpn_server 'ovpns3', description='ovpns3'" + ) def test_openvpn_server_create_generate(self): - """ test creation of a new OpenVPN server """ - obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA', local_port=1196, tls='generate') - self.do_module_test(obj, command="create openvpn_server 'ovpns3', description='ovpns3'") + """test creation of a new OpenVPN server""" + obj = dict( + name="ovpns3", + mode="p2p_tls", + ca="OpenVPN CA", + local_port=1196, + tls="generate", + ) + self.do_module_test( + obj, command="create openvpn_server 'ovpns3', description='ovpns3'" + ) def test_openvpn_server_delete(self): - """ test deletion of a OpenVPN server """ - obj = dict(name='ovpns2') + """test deletion of a OpenVPN server""" + obj = dict(name="ovpns2") self.do_module_test(obj, command="delete openvpn_server 'ovpns2'", delete=True) def test_openvpn_server_update_noop(self): - """ test not updating a OpenVPN server """ - obj = dict(name='ovpns2', mode='p2p_tls', ca='OpenVPN CA', local_port=1195, tls=TLSKEY, tls_type='auth') + """test not updating a OpenVPN server""" + obj = dict( + name="ovpns2", + mode="p2p_tls", + ca="OpenVPN CA", + local_port=1195, + tls=TLSKEY, + tls_type="auth", + ) self.do_module_test(obj, changed=False) def test_openvpn_server_update_network(self): - """ test updating network of a OpenVPN server """ - obj = dict(name='ovpns2', mode='p2p_tls', ca='OpenVPN CA', local_port=1195, tls=TLSKEY, tls_type='auth', tunnel_network='10.10.10.10/24') + """test updating network of a OpenVPN server""" + obj = dict( + name="ovpns2", + mode="p2p_tls", + ca="OpenVPN CA", + local_port=1195, + tls=TLSKEY, + tls_type="auth", + tunnel_network="10.10.10.10/24", + ) self.do_module_test(obj, command="update openvpn_server 'ovpns2' set ") ############## # misc # def test_create_openvpn_server_duplicate_port(self): - """ test creation of a new OpenVPN server with duplicate port """ - obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA') - self.do_module_test(obj, failed=True, msg='The specified local_port (1194) is in use by vpn ID 1') + """test creation of a new OpenVPN server with duplicate port""" + obj = dict(name="ovpns3", mode="p2p_tls", ca="OpenVPN CA") + self.do_module_test( + obj, + failed=True, + msg="The specified local_port (1194) is in use by vpn ID 1", + ) def test_create_openvpn_server_invalid_certificate(self): - """ test creation of a new OpenVPN server with invalid certificate """ - obj = dict(name='ovpns2', mode='p2p_tls', ca='OpenVPN CA', cert='blah') - self.do_module_test(obj, failed=True, msg='blah is not a valid certificate') + """test creation of a new OpenVPN server with invalid certificate""" + obj = dict(name="ovpns2", mode="p2p_tls", ca="OpenVPN CA", cert="blah") + self.do_module_test(obj, failed=True, msg="blah is not a valid certificate") def test_delete_nonexistent_openvpn_server(self): - """ test deletion of an nonexistent OpenVPN server """ - obj = dict(name='novpn') - self.do_module_test(obj, commmand=None, state='absent', changed=False) + """test deletion of an nonexistent OpenVPN server""" + obj = dict(name="novpn") + self.do_module_test(obj, commmand=None, state="absent", changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_route.py b/tests/unit/plugins/modules/test_pfsense_route.py index 07376215..9600b2ff 100644 --- a/tests/unit/plugins/modules/test_pfsense_route.py +++ b/tests/unit/plugins/modules/test_pfsense_route.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,49 +12,54 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_route -from ansible_collections.pfsensible.core.plugins.module_utils.route import PFSenseRouteModule +from ansible_collections.pfsensible.core.plugins.module_utils.route import ( + PFSenseRouteModule, +) from .pfsense_module import TestPFSenseModule -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) class TestPFSenseRouteModule(TestPFSenseModule): - module = pfsense_route def __init__(self, *args, **kwargs): super(TestPFSenseRouteModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_route_config.xml' + self.config_file = "pfsense_route_config.xml" self.pfmodule = PFSenseRouteModule def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseRouteModule, self).setUp() - self.mock_run_command = patch('ansible.module_utils.basic.AnsibleModule.run_command') + self.mock_run_command = patch( + "ansible.module_utils.basic.AnsibleModule.run_command" + ) self.run_command = self.mock_run_command.start() - self.run_command.return_value = (0, '', '') + self.run_command.return_value = (0, "", "") def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseRouteModule, self).tearDown() self.run_command.stop() def check_target_elt(self, obj, target_elt): - """ test the xml definition """ + """test the xml definition""" - self.check_param_equal_or_not_find(obj, target_elt, 'disabled') - self.check_param_equal(obj, target_elt, 'gateway') - self.check_param_equal(obj, target_elt, 'network') + self.check_param_equal_or_not_find(obj, target_elt, "disabled") + self.check_param_equal(obj, target_elt, "gateway") + self.check_param_equal(obj, target_elt, "network") def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - root_elt = self.assert_find_xml_elt(self.xml_result, 'staticroutes') + """get the generated xml definition""" + root_elt = self.assert_find_xml_elt(self.xml_result, "staticroutes") for item in root_elt: - name_elt = item.find('descr') - if name_elt is not None and name_elt.text == obj['descr']: + name_elt = item.find("descr") + if name_elt is not None and name_elt.text == obj["descr"]: return item return None @@ -62,72 +68,72 @@ def get_target_elt(self, obj, absent=False, module_result=None): # tests # def test_route_create(self): - """ test """ - obj = dict(descr='test_route', network='1.2.3.4/24', gateway='GW_LAN') + """test""" + obj = dict(descr="test_route", network="1.2.3.4/24", gateway="GW_LAN") command = "create route 'test_route', network='1.2.3.4/24', gateway='GW_LAN'" self.do_module_test(obj, command=command) def test_route_create_invalid_gw(self): - """ test """ - obj = dict(descr='test_route', network='1.2.3.4/24', gateway='GW_INVALID') + """test""" + obj = dict(descr="test_route", network="1.2.3.4/24", gateway="GW_INVALID") msg = "The gateway GW_INVALID does not exist" self.do_module_test(obj, msg=msg, failed=True) def test_route_create_invalid_ip(self): - """ test """ - obj = dict(descr='test_route', network='2001::1', gateway='GW_LAN') + """test""" + obj = dict(descr="test_route", network="2001::1", gateway="GW_LAN") msg = 'The gateway "192.168.1.1" is a different Address Family than network "2001::1".' self.do_module_test(obj, msg=msg, failed=True) def test_route_create_invalid_ip2(self): - """ test """ - obj = dict(descr='test_route', network='1.2.3.4', gateway='GW_LAN_V6') + """test""" + obj = dict(descr="test_route", network="1.2.3.4", gateway="GW_LAN_V6") msg = 'The gateway "2002::1" is a different Address Family than network "1.2.3.4".' self.do_module_test(obj, msg=msg, failed=True) def test_route_create_invalid_alias(self): - """ test """ - obj = dict(descr='test_route', network='invalid_alias', gateway='GW_LAN') - msg = 'A valid IPv4 or IPv6 destination network or alias must be specified.' + """test""" + obj = dict(descr="test_route", network="invalid_alias", gateway="GW_LAN") + msg = "A valid IPv4 or IPv6 destination network or alias must be specified." self.do_module_test(obj, msg=msg, failed=True) def test_route_update_noop(self): - """ test """ - obj = dict(descr='GW_WAN route', network='10.3.0.0/16', gateway='GW_WAN') + """test""" + obj = dict(descr="GW_WAN route", network="10.3.0.0/16", gateway="GW_WAN") self.do_module_test(obj, changed=False) def test_route_update_network(self): - """ test """ - obj = dict(descr='GW_WAN route', network='10.4.0.0/16', gateway='GW_WAN') + """test""" + obj = dict(descr="GW_WAN route", network="10.4.0.0/16", gateway="GW_WAN") command = "update route 'GW_WAN route' set network='10.4.0.0/16'" self.do_module_test(obj, command=command) def test_route_update_gateway(self): - """ test """ - obj = dict(descr='GW_WAN route', network='10.3.0.0/16', gateway='GW_LAN') + """test""" + obj = dict(descr="GW_WAN route", network="10.3.0.0/16", gateway="GW_LAN") command = "update route 'GW_WAN route' set gateway='GW_LAN'" self.do_module_test(obj, command=command) def test_route_delete(self): - """ test """ - obj = dict(descr='GW_WAN route') + """test""" + obj = dict(descr="GW_WAN route") command = "delete route 'GW_WAN route'" self.do_module_test(obj, command=command, delete=True) def test_route_delete_alias(self): - """ test """ - obj = dict(descr='GW_WAN alias') + """test""" + obj = dict(descr="GW_WAN alias") command = "delete route 'GW_WAN alias'" self.do_module_test(obj, command=command, delete=True) def test_route_create_dhcp(self): - """ test """ - obj = dict(descr='test_route', network='1.2.3.4/24', gateway='VPN_DHCP') + """test""" + obj = dict(descr="test_route", network="1.2.3.4/24", gateway="VPN_DHCP") command = "create route 'test_route', network='1.2.3.4/24', gateway='VPN_DHCP'" self.do_module_test(obj, command=command) def test_route_create_dhcp6(self): - """ test """ - obj = dict(descr='test_route', network='2001::/56', gateway='VPN_DHCP6') + """test""" + obj = dict(descr="test_route", network="2001::/56", gateway="VPN_DHCP6") command = "create route 'test_route', network='2001::/56', gateway='VPN_DHCP6'" self.do_module_test(obj, command=command) diff --git a/tests/unit/plugins/modules/test_pfsense_rule.py b/tests/unit/plugins/modules/test_pfsense_rule.py index f6dbc2ce..e0f9d672 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule.py +++ b/tests/unit/plugins/modules/test_pfsense_rule.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -17,206 +18,221 @@ class TestPFSenseRuleModule(TestPFSenseModule): - module = pfsense_rule def __init__(self, *args, **kwargs): super(TestPFSenseRuleModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_rule_config.xml' + self.config_file = "pfsense_rule_config.xml" self.pfmodule = PFSenseRuleModule @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def parse_address(self, addr): - """ return address parsed in dict """ + """return address parsed in dict""" if is_ipv6_address(addr) or is_ipv6_network(addr): parts = [addr] else: parts = addr.split(':') res = {} - if parts[0][0] == '!': - res['not'] = None + if parts[0][0] == "!": + res["not"] = None parts[0] = parts[0][1:] - if parts[0] == 'any': - res['any'] = None - elif parts[0] == '(self)': - res['network'] = '(self)' - elif parts[0] == 'NET': - res['network'] = self.unalias_interface(parts[1]) + if parts[0] == "any": + res["any"] = None + elif parts[0] == "(self)": + res["network"] = "(self)" + elif parts[0] == "NET": + res["network"] = self.unalias_interface(parts[1]) del parts[1] - elif parts[0] == 'IP': - res['network'] = self.unalias_interface(parts[1]) + 'ip' + elif parts[0] == "IP": + res["network"] = self.unalias_interface(parts[1]) + "ip" del parts[1] - elif parts[0] in ['lan', 'lan', 'vpn', 'vt1', 'lan_100']: - res['network'] = self.unalias_interface(parts[0]) + elif parts[0] in ["lan", "lan", "vpn", "vt1", "lan_100"]: + res["network"] = self.unalias_interface(parts[0]) else: - res['address'] = parts[0] + res["address"] = parts[0] if len(parts) > 1: - res['port'] = parts[1] + res["port"] = parts[1] return res def check_rule_elt_addr(self, rule, rule_elt, addr): - """ test the addresses definition of rule """ + """test the addresses definition of rule""" addr_dict = self.parse_address(rule[addr]) addr_elt = self.assert_find_xml_elt(rule_elt, addr) for key, value in addr_dict.items(): self.assert_xml_elt_equal(addr_elt, key, value) - if 'any' in addr_dict: - self.assert_not_find_xml_elt(addr_elt, 'address') - self.assert_not_find_xml_elt(addr_elt, 'network') - if 'network' in addr_dict: - self.assert_not_find_xml_elt(addr_elt, 'address') - self.assert_not_find_xml_elt(addr_elt, 'any') - if 'address' in addr_dict: - self.assert_not_find_xml_elt(addr_elt, 'network') - self.assert_not_find_xml_elt(addr_elt, 'any') - - if 'not' not in addr_dict: - self.assert_not_find_xml_elt(addr_elt, 'not') + if "any" in addr_dict: + self.assert_not_find_xml_elt(addr_elt, "address") + self.assert_not_find_xml_elt(addr_elt, "network") + if "network" in addr_dict: + self.assert_not_find_xml_elt(addr_elt, "address") + self.assert_not_find_xml_elt(addr_elt, "any") + if "address" in addr_dict: + self.assert_not_find_xml_elt(addr_elt, "network") + self.assert_not_find_xml_elt(addr_elt, "any") + + if "not" not in addr_dict: + self.assert_not_find_xml_elt(addr_elt, "not") def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ - obj['interface'] = self.unalias_interface(obj['interface']) - if 'floating' in obj and obj['floating'] == 'yes': - return self.assert_has_xml_tag('filter', dict(descr=obj['name'], floating='yes'), absent=absent) - return self.assert_has_xml_tag('filter', dict(descr=obj['name'], interface=obj['interface']), absent=absent) + """return target elt from XML""" + obj["interface"] = self.unalias_interface(obj["interface"]) + if "floating" in obj and obj["floating"] == "yes": + return self.assert_has_xml_tag( + "filter", dict(descr=obj["name"], floating="yes"), absent=absent + ) + return self.assert_has_xml_tag( + "filter", dict(descr=obj["name"], interface=obj["interface"]), absent=absent + ) def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ + """check XML definition of target elt""" # checking source address and ports - self.check_rule_elt_addr(obj, target_elt, 'source') + self.check_rule_elt_addr(obj, target_elt, "source") # checking destination address and ports - self.check_rule_elt_addr(obj, target_elt, 'destination') + self.check_rule_elt_addr(obj, target_elt, "destination") # checking log option - if 'log' in obj and obj['log'] == 'yes': - self.assert_xml_elt_is_none_or_empty(target_elt, 'log') - elif 'log' not in obj or obj['log'] == 'no': - self.assert_not_find_xml_elt(target_elt, 'log') + if "log" in obj and obj["log"] == "yes": + self.assert_xml_elt_is_none_or_empty(target_elt, "log") + elif "log" not in obj or obj["log"] == "no": + self.assert_not_find_xml_elt(target_elt, "log") # checking action option - if 'action' in obj: - action = obj['action'] + if "action" in obj: + action = obj["action"] else: - action = 'pass' - self.assert_xml_elt_equal(target_elt, 'type', action) + action = "pass" + self.assert_xml_elt_equal(target_elt, "type", action) # checking floating option - if 'floating' in obj and obj['floating'] == 'yes': - self.assert_xml_elt_equal(target_elt, 'floating', 'yes') - if 'quick' in obj and obj['quick'] == 'yes': - self.assert_xml_elt_equal(target_elt, 'quick', 'yes') + if "floating" in obj and obj["floating"] == "yes": + self.assert_xml_elt_equal(target_elt, "floating", "yes") + if "quick" in obj and obj["quick"] == "yes": + self.assert_xml_elt_equal(target_elt, "quick", "yes") else: - self.assert_not_find_xml_elt(target_elt, 'quick') + self.assert_not_find_xml_elt(target_elt, "quick") - elif 'floating' not in obj or obj['floating'] == 'no': - self.assert_not_find_xml_elt(target_elt, 'floating') - self.assert_not_find_xml_elt(target_elt, 'quick') + elif "floating" not in obj or obj["floating"] == "no": + self.assert_not_find_xml_elt(target_elt, "floating") + self.assert_not_find_xml_elt(target_elt, "quick") # checking direction option - self.check_param_equal_or_not_find(obj, target_elt, 'direction') + self.check_param_equal_or_not_find(obj, target_elt, "direction") # checking default queue option - self.check_param_equal_or_not_find(obj, target_elt, 'queue', 'defaultqueue') + self.check_param_equal_or_not_find(obj, target_elt, "queue", "defaultqueue") # checking acknowledge queue option - self.check_param_equal_or_not_find(obj, target_elt, 'ackqueue') + self.check_param_equal_or_not_find(obj, target_elt, "ackqueue") # limiters - self.check_param_equal_or_not_find(obj, target_elt, 'in_queue', 'dnpipe') - self.check_param_equal_or_not_find(obj, target_elt, 'out_queue', 'pdnpipe') + self.check_param_equal_or_not_find(obj, target_elt, "in_queue", "dnpipe") + self.check_param_equal_or_not_find(obj, target_elt, "out_queue", "pdnpipe") # schedule - self.check_param_equal_or_not_find(obj, target_elt, 'sched') + self.check_param_equal_or_not_find(obj, target_elt, "sched") # checking ipprotocol option - if 'ipprotocol' in obj: - action = obj['ipprotocol'] + if "ipprotocol" in obj: + action = obj["ipprotocol"] else: - action = 'inet' - self.assert_xml_elt_equal(target_elt, 'ipprotocol', action) + action = "inet" + self.assert_xml_elt_equal(target_elt, "ipprotocol", action) # checking protocol option - if 'protocol' in obj and obj['protocol'] != 'any': - self.assert_xml_elt_equal(target_elt, 'protocol', obj['protocol']) + if "protocol" in obj and obj["protocol"] != "any": + self.assert_xml_elt_equal(target_elt, "protocol", obj["protocol"]) else: - self.assert_not_find_xml_elt(target_elt, 'protocol') + self.assert_not_find_xml_elt(target_elt, "protocol") # checking tcpflags_any option - if 'tcpflags_any' in obj and obj['tcpflags_any'] == 'yes': - self.assert_xml_elt_is_none_or_empty(target_elt, 'tcpflags_any') - elif 'tcpflags_any' not in obj or obj['tcpflags_any'] == 'no': - self.assert_not_find_xml_elt(target_elt, 'tcpflags_any') + if "tcpflags_any" in obj and obj["tcpflags_any"] == "yes": + self.assert_xml_elt_is_none_or_empty(target_elt, "tcpflags_any") + elif "tcpflags_any" not in obj or obj["tcpflags_any"] == "no": + self.assert_not_find_xml_elt(target_elt, "tcpflags_any") # checking statetype option - if 'statetype' in obj and obj['statetype'] != 'keep state': - statetype = obj['statetype'] + if "statetype" in obj and obj["statetype"] != "keep state": + statetype = obj["statetype"] else: - statetype = 'keep state' - self.assert_xml_elt_equal(target_elt, 'statetype', statetype) + statetype = "keep state" + self.assert_xml_elt_equal(target_elt, "statetype", statetype) # checking disabled option - if 'disabled' in obj and obj['disabled'] == 'yes': - self.assert_xml_elt_is_none_or_empty(target_elt, 'disabled') - elif 'disabled' not in obj or obj['disabled'] == 'no': - self.assert_not_find_xml_elt(target_elt, 'disabled') + if "disabled" in obj and obj["disabled"] == "yes": + self.assert_xml_elt_is_none_or_empty(target_elt, "disabled") + elif "disabled" not in obj or obj["disabled"] == "no": + self.assert_not_find_xml_elt(target_elt, "disabled") # checking gateway option - if 'gateway' in obj and obj['gateway'] != 'default': - self.assert_xml_elt_equal(target_elt, 'gateway', obj['gateway']) + if "gateway" in obj and obj["gateway"] != "default": + self.assert_xml_elt_equal(target_elt, "gateway", obj["gateway"]) else: - self.assert_not_find_xml_elt(target_elt, 'gateway') + self.assert_not_find_xml_elt(target_elt, "gateway") # checking tracker - if 'tracker' in obj: - self.assert_xml_elt_equal(target_elt, 'tracker', obj['tracker']) + if "tracker" in obj: + self.assert_xml_elt_equal(target_elt, "tracker", obj["tracker"]) # checking icmptype - if 'icmptype' in obj: - self.assert_xml_elt_equal(target_elt, 'icmptype', obj['icmptype']) + if "icmptype" in obj: + self.assert_xml_elt_equal(target_elt, "icmptype", obj["icmptype"]) def check_rule_idx(self, rule, target_idx): - """ test the xml position of rule """ - floating = 'floating' in rule and rule['floating'] == 'yes' - rule['interface'] = self.unalias_interface(rule['interface']) - rules_elt = self.assert_find_xml_elt(self.xml_result, 'filter') + """test the xml position of rule""" + floating = "floating" in rule and rule["floating"] == "yes" + rule["interface"] = self.unalias_interface(rule["interface"]) + rules_elt = self.assert_find_xml_elt(self.xml_result, "filter") idx = -1 for rule_elt in rules_elt: - interface_elt = rule_elt.find('interface') - floating_elt = rule_elt.find('floating') - floating_rule = floating_elt is not None and floating_elt.text == 'yes' + interface_elt = rule_elt.find("interface") + floating_elt = rule_elt.find("floating") + floating_rule = floating_elt is not None and floating_elt.text == "yes" if floating and not floating_rule: continue if not floating: - if floating_rule or interface_elt is None or interface_elt.text is None or interface_elt.text != rule['interface']: + if ( + floating_rule + or interface_elt is None + or interface_elt.text is None + or interface_elt.text != rule["interface"] + ): continue idx += 1 - descr_elt = rule_elt.find('descr') + descr_elt = rule_elt.find("descr") self.assertIsNotNone(descr_elt) self.assertIsNotNone(descr_elt.text) - if descr_elt.text == rule['name']: + if descr_elt.text == rule["name"]: self.assertEqual(idx, target_idx) return - self.fail('rule not found ' + str(idx)) + self.fail("rule not found " + str(idx)) def check_separator_idx(self, interface, sep_name, expected_idx): - """ test the logical position of separator """ - filter_elt = self.assert_find_xml_elt(self.xml_result, 'filter') - separator_elt = self.assert_find_xml_elt(filter_elt, 'separator') + """test the logical position of separator""" + filter_elt = self.assert_find_xml_elt(self.xml_result, "filter") + separator_elt = self.assert_find_xml_elt(filter_elt, "separator") iface_elt = self.assert_find_xml_elt(separator_elt, interface) for separator in iface_elt: - text_elt = separator.find('text') + text_elt = separator.find("text") if text_elt is not None and text_elt.text == sep_name: - row_elt = self.assert_find_xml_elt(separator, 'row') - idx = int(row_elt.text.replace('fr', '')) + row_elt = self.assert_find_xml_elt(separator, "row") + idx = int(row_elt.text.replace("fr", "")) if idx != expected_idx: - self.fail('Idx of separator ' + sep_name + ' if wrong: ' + str(idx) + ', expected: ' + str(expected_idx)) + self.fail( + "Idx of separator " + + sep_name + + " if wrong: " + + str(idx) + + ", expected: " + + str(expected_idx) + ) return - self.fail('Separator ' + sep_name + 'not found on interface ' + interface) + self.fail("Separator " + sep_name + "not found on interface " + interface) diff --git a/tests/unit/plugins/modules/test_pfsense_rule_create.py b/tests/unit/plugins/modules/test_pfsense_rule_create.py index f300e07f..410d2b34 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule_create.py +++ b/tests/unit/plugins/modules/test_pfsense_rule_create.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -14,318 +15,596 @@ class TestPFSenseRuleCreateModule(TestPFSenseRuleModule): - ############################ # rule creation tests # def test_rule_create_one_rule(self): - """ test creation of a new rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan') + """test creation of a new rule""" + obj = dict(name="one_rule", source="any", destination="any", interface="lan") command = "create rule 'one_rule' on 'lan', source='any', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_log(self): - """ test creation of a new rule with logging """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', log='yes') - command = "create rule 'one_rule' on 'lan', source='any', destination='any', log=True" + """test creation of a new rule with logging""" + obj = dict( + name="one_rule", source="any", destination="any", interface="lan", log="yes" + ) + command = ( + "create rule 'one_rule' on 'lan', source='any', destination='any', log=True" + ) self.do_module_test(obj, command=command) def test_rule_create_nolog(self): - """ test creation of a new rule without logging """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', log='no') + """test creation of a new rule without logging""" + obj = dict( + name="one_rule", source="any", destination="any", interface="lan", log="no" + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_pass(self): - """ test creation of a new rule explictly passing """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', action='pass') + """test creation of a new rule explictly passing""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + action="pass", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_block(self): - """ test creation of a new rule blocking """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', action='block') + """test creation of a new rule blocking""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + action="block", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', action='block'" self.do_module_test(obj, command=command) def test_rule_create_reject(self): - """ test creation of a new rule rejecting """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', action='reject') + """test creation of a new rule rejecting""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + action="reject", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', action='reject'" self.do_module_test(obj, command=command) def test_rule_create_disabled(self): - """ test creation of a new disabled rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', disabled=True) + """test creation of a new disabled rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + disabled=True, + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', disabled=True" self.do_module_test(obj, command=command) def test_rule_create_floating(self): - """ test creation of a new floating rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', floating='yes', direction='any') + """test creation of a new floating rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + floating="yes", + direction="any", + ) command = "create rule 'one_rule' on 'floating(lan)', source='any', destination='any', direction='any'" self.do_module_test(obj, command=command) def test_rule_create_floating_any(self): - """ test creation of a new floating rule with any interface """ - obj = dict(name='one_rule', source='any', destination='any', interface='any', floating='yes', direction='any') + """test creation of a new floating rule with any interface""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="any", + floating="yes", + direction="any", + ) command = "create rule 'one_rule' on 'floating(any)', source='any', destination='any', direction='any'" def test_rule_create_non_floating_any(self): - """ test creation of a new rule with any interface """ - obj = dict(name='one_rule', source='any', destination='any', interface='any', floating='no', direction='any') + """test creation of a new rule with any interface""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="any", + floating="no", + direction="any", + ) msg = "any is not a valid interface" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_floating_quick(self): - """ test creation of a new floating rule with quick match """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', floating='yes', direction='any', quick='yes') + """test creation of a new floating rule with quick match""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + floating="yes", + direction="any", + quick="yes", + ) command = "create rule 'one_rule' on 'floating(lan)', source='any', destination='any', direction='any', quick=True" self.do_module_test(obj, command=command) def test_rule_create_nofloating(self): - """ test creation of a new non-floating rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', floating='no') + """test creation of a new non-floating rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + floating="no", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_floating_interfaces(self): - """ test creation of a floating rule on three interfaces """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan,wan,vt1', floating='yes', direction='any') + """test creation of a floating rule on three interfaces""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan,wan,vt1", + floating="yes", + direction="any", + ) command = "create rule 'one_rule' on 'floating(lan,wan,vt1)', source='any', destination='any', direction='any'" self.do_module_test(obj, command=command) def test_rule_create_inet46(self): - """ test creation of a new rule using ipv4 and ipv6 """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', ipprotocol='inet46') + """test creation of a new rule using ipv4 and ipv6""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + ipprotocol="inet46", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', ipprotocol='inet46'" self.do_module_test(obj, command=command) def test_rule_create_inet6(self): - """ test creation of a new rule using ipv6 """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', ipprotocol='inet6') + """test creation of a new rule using ipv6""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + ipprotocol="inet6", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', ipprotocol='inet6'" self.do_module_test(obj, command=command) def test_rule_create_tcp(self): - """ test creation of a new rule for tcp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='tcp') + """test creation of a new rule for tcp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_udp(self): - """ test creation of a new rule for udp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='udp') + """test creation of a new rule for udp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="udp", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', protocol='udp'" self.do_module_test(obj, command=command) def test_rule_create_tcp_udp(self): - """ test creation of a new rule for tcp/udp protocols """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='tcp/udp') + """test creation of a new rule for tcp/udp protocols""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="tcp/udp", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', protocol='tcp/udp'" self.do_module_test(obj, command=command) def test_rule_create_icmp(self): - """ test creation of a new rule for icmp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='icmp') + """test creation of a new rule for icmp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="icmp", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', protocol='icmp'" self.do_module_test(obj, command=command) def test_rule_create_icmp_redir(self): - """ test creation of a new rule for icmp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='icmp', icmptype='redir', action='block') + """test creation of a new rule for icmp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="icmp", + icmptype="redir", + action="block", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', protocol='icmp', icmptype='redir', action='block'" self.do_module_test(obj, command=command) def test_rule_create_icmp_invalid_inet(self): - """ test creation of a new rule for icmp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='icmp', icmptype='neighbradv') - msg = 'ICMP types neighbradv are invalid with IP type inet' + """test creation of a new rule for icmp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="icmp", + icmptype="neighbradv", + ) + msg = "ICMP types neighbradv are invalid with IP type inet" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_icmp_invalid_inet6(self): - """ test creation of a new rule for icmp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='icmp', ipprotocol='inet6', icmptype='trace') - msg = 'ICMP types trace are invalid with IP type inet6' + """test creation of a new rule for icmp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="icmp", + ipprotocol="inet6", + icmptype="trace", + ) + msg = "ICMP types trace are invalid with IP type inet6" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_icmp_invalid_inet46(self): - """ test creation of a new rule for icmp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='icmp', ipprotocol='inet46', icmptype='trace') - msg = 'ICMP types trace are invalid with IP type inet46' + """test creation of a new rule for icmp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="icmp", + ipprotocol="inet46", + icmptype="trace", + ) + msg = "ICMP types trace are invalid with IP type inet46" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_icmp_invalid_empty(self): - """ test creation of a new rule for icmp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='icmp', icmptype='') - msg = 'You must specify at least one icmptype or any for all of them' + """test creation of a new rule for icmp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="icmp", + icmptype="", + ) + msg = "You must specify at least one icmptype or any for all of them" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_esp(self): - """ test creation of a new rule for esp protocol """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='esp') + """test creation of a new rule for esp protocol""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="esp", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', protocol='esp'" self.do_module_test(obj, command=command) def test_rule_create_protocol_any(self): - """ test creation of a new rule for (self) """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', protocol='any') + """test creation of a new rule for (self)""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + protocol="any", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_tcpflags_any(self): - """ test creation of a new rule with tcpflags_any """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', tcpflags_any='yes') + """test creation of a new rule with tcpflags_any""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + tcpflags_any="yes", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', tcpflags_any=True" self.do_module_test(obj, command=command) def test_rule_create_state_keep(self): - """ test creation of a new rule with explicit keep state """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', statetype='keep state') + """test creation of a new rule with explicit keep state""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + statetype="keep state", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_state_sloppy(self): - """ test creation of a new rule with sloppy state """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', statetype='sloppy state') + """test creation of a new rule with sloppy state""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + statetype="sloppy state", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', statetype='sloppy state'" self.do_module_test(obj, command=command) def test_rule_create_state_synproxy(self): - """ test creation of a new rule with synproxy state """ + """test creation of a new rule with synproxy state""" # todo: synproxy is only valid with tcp - obj = dict(name='one_rule', source='any', destination='any', interface='lan', statetype='synproxy state') + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + statetype="synproxy state", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', statetype='synproxy state'" self.do_module_test(obj, command=command) def test_rule_create_state_none(self): - """ test creation of a new rule with no state """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', statetype='none') + """test creation of a new rule with no state""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + statetype="none", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', statetype='none'" self.do_module_test(obj, command=command) def test_rule_create_state_invalid(self): - """ test creation of a new rule with invalid state """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', statetype='acme state') + """test creation of a new rule with invalid state""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + statetype="acme state", + ) msg = "value of statetype must be one of: keep state, sloppy state, synproxy state, none, got: acme state" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_after(self): - """ test creation of a new rule after another """ - obj = dict(name='one_rule', source='any', destination='any', interface='vpn', after='admin_bypass') + """test creation of a new rule after another""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="vpn", + after="admin_bypass", + ) command = "create rule 'one_rule' on 'vpn', source='any', destination='any', after='admin_bypass'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 13) def test_rule_create_after_top(self): - """ test creation of a new rule at top """ - obj = dict(name='one_rule', source='any', destination='any', interface='wan', after='top') + """test creation of a new rule at top""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="wan", + after="top", + ) command = "create rule 'one_rule' on 'wan', source='any', destination='any', after='top'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) def test_rule_create_after_invalid(self): - """ test creation of a new rule after an invalid rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', after='admin_bypass') + """test creation of a new rule after an invalid rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + after="admin_bypass", + ) msg = "Failed to insert after rule=admin_bypass interface=lan" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_before(self): - """ test creation of a new rule before another """ - obj = dict(name='one_rule', source='any', destination='any', interface='vpn', before='admin_bypass') + """test creation of a new rule before another""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="vpn", + before="admin_bypass", + ) command = "create rule 'one_rule' on 'vpn', source='any', destination='any', before='admin_bypass'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 12) def test_rule_create_before_bottom(self): - """ test creation of a new rule at bottom """ - obj = dict(name='one_rule', source='any', destination='any', interface='wan', before='bottom') + """test creation of a new rule at bottom""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="wan", + before="bottom", + ) command = "create rule 'one_rule' on 'wan', source='any', destination='any', before='bottom'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 4) def test_rule_create_before_bottom_default(self): - """ test creation of a new rule at bottom (default) """ - obj = dict(name='one_rule', source='any', destination='any', interface='wan', action='pass') + """test creation of a new rule at bottom (default)""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="wan", + action="pass", + ) command = "create rule 'one_rule' on 'wan', source='any', destination='any'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 4) def test_rule_create_before_invalid(self): - """ test creation of a new rule before an invalid rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', before='admin_bypass') + """test creation of a new rule before an invalid rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + before="admin_bypass", + ) msg = "Failed to insert before rule=admin_bypass interface=lan" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_source_alias(self): - """ test creation of a new rule with a valid source alias """ - obj = dict(name='one_rule', source='srv_admin', destination='any', interface='lan') - command = "create rule 'one_rule' on 'lan', source='srv_admin', destination='any'" + """test creation of a new rule with a valid source alias""" + obj = dict( + name="one_rule", source="srv_admin", destination="any", interface="lan" + ) + command = ( + "create rule 'one_rule' on 'lan', source='srv_admin', destination='any'" + ) self.do_module_test(obj, command=command) def test_rule_create_source_urltable_alias(self): - """ test creation of a new rule with a valid source urltable alias """ - obj = dict(name='one_rule', source='acme_corp', destination='any', interface='lan') - command = "create rule 'one_rule' on 'lan', source='acme_corp', destination='any'" + """test creation of a new rule with a valid source urltable alias""" + obj = dict( + name="one_rule", source="acme_corp", destination="any", interface="lan" + ) + command = ( + "create rule 'one_rule' on 'lan', source='acme_corp', destination='any'" + ) self.do_module_test(obj, command=command) def test_rule_create_source_alias_invalid(self): - """ test creation of a new rule with an invalid source alias """ - obj = dict(name='one_rule', source='acme', destination='any', interface='lan') + """test creation of a new rule with an invalid source alias""" + obj = dict(name="one_rule", source="acme", destination="any", interface="lan") msg = "Cannot parse address acme, not IP or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_invalid_ports(self): - """ test creation of a new rule with an invalid use of ports """ - obj = dict(name='one_rule', source='192.193.194.195', destination='any:22', interface='lan', protocol='icmp') + """test creation of a new rule with an invalid use of ports""" + obj = dict( + name="one_rule", + source="192.193.194.195", + destination="any:22", + interface="lan", + protocol="icmp", + ) msg = "'one_rule' on 'lan': you can't use ports on protocols other than tcp, udp or tcp/udp" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_source_ip_invalid(self): - """ test creation of a new rule with an invalid source ip """ - obj = dict(name='one_rule', source='192.193.194.195.196', destination='any', interface='lan') + """test creation of a new rule with an invalid source ip""" + obj = dict( + name="one_rule", + source="192.193.194.195.196", + destination="any", + interface="lan", + ) msg = "Cannot parse address 192.193.194.195.196, not IP or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_source_net_invalid(self): - """ test creation of a new rule with an invalid source network """ - obj = dict(name='one_rule', source='192.193.194.195/256', destination='any', interface='lan') + """test creation of a new rule with an invalid source network""" + obj = dict( + name="one_rule", + source="192.193.194.195/256", + destination="any", + interface="lan", + ) msg = "Cannot parse address 192.193.194.195/256, not IP or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_destination_alias(self): - """ test creation of a new rule with a valid destination alias """ - obj = dict(name='one_rule', source='any', destination='srv_admin', interface='lan') - command = "create rule 'one_rule' on 'lan', source='any', destination='srv_admin'" + """test creation of a new rule with a valid destination alias""" + obj = dict( + name="one_rule", source="any", destination="srv_admin", interface="lan" + ) + command = ( + "create rule 'one_rule' on 'lan', source='any', destination='srv_admin'" + ) self.do_module_test(obj, command=command) def test_rule_create_destination_alias_invalid(self): - """ test creation of a new rule with an invalid destination alias """ - obj = dict(name='one_rule', source='any', destination='acme', interface='lan') + """test creation of a new rule with an invalid destination alias""" + obj = dict(name="one_rule", source="any", destination="acme", interface="lan") msg = "Cannot parse address acme, not IP or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_destination_ip_invalid(self): - """ test creation of a new rule with an invalid destination ip """ - obj = dict(name='one_rule', source='any', destination='192.193.194.195.196', interface='lan') + """test creation of a new rule with an invalid destination ip""" + obj = dict( + name="one_rule", + source="any", + destination="192.193.194.195.196", + interface="lan", + ) msg = "Cannot parse address 192.193.194.195.196, not IP or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_destination_net_invalid(self): - """ test creation of a new rule with an invalid destination network """ - obj = dict(name='one_rule', source='any', destination='192.193.194.195/256', interface='lan') + """test creation of a new rule with an invalid destination network""" + obj = dict( + name="one_rule", + source="any", + destination="192.193.194.195/256", + interface="lan", + ) msg = "Cannot parse address 192.193.194.195/256, not IP or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_source_self_lan(self): - """ test creation of a new rule with self""" - obj = dict(name='one_rule', source='(self)', destination='any', interface='lan') + """test creation of a new rule with self""" + obj = dict(name="one_rule", source="(self)", destination="any", interface="lan") command = "create rule 'one_rule' on 'lan', source='(self)', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_ip_to_ip(self): - """ test creation of a new rule with valid ips """ - obj = dict(name='one_rule', source='10.10.1.1', destination='10.10.10.1', interface='lan') + """test creation of a new rule with valid ips""" + obj = dict( + name="one_rule", + source="10.10.1.1", + destination="10.10.10.1", + interface="lan", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.1', destination='10.10.10.1'" self.do_module_test(obj, command=command) @@ -336,8 +615,13 @@ def test_rule_create_ip6_to_ip6(self): self.do_module_test(obj, command=command) def test_rule_create_net_to_net(self): - """ test creation of a new rule valid networks """ - obj = dict(name='one_rule', source='10.10.1.0/24', destination='10.10.10.0/24', interface='lan') + """test creation of a new rule valid networks""" + obj = dict( + name="one_rule", + source="10.10.1.0/24", + destination="10.10.10.0/24", + interface="lan", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.0/24', destination='10.10.10.0/24'" self.do_module_test(obj, command=command) @@ -348,307 +632,582 @@ def test_rule_create_net6_to_net6(self): self.do_module_test(obj, command=command) def test_rule_create_net_interface(self): - """ test creation of a new rule with valid interface """ - obj = dict(name='one_rule', source='NET:lan', destination='any', interface='lan') + """test creation of a new rule with valid interface""" + obj = dict( + name="one_rule", source="NET:lan", destination="any", interface="lan" + ) command = "create rule 'one_rule' on 'lan', source='NET:lan', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_net_interface_invalid(self): - """ test creation of a new rule with invalid interface """ - obj = dict(name='one_rule', source='NET:invalid_lan', destination='any', interface='lan') + """test creation of a new rule with invalid interface""" + obj = dict( + name="one_rule", + source="NET:invalid_lan", + destination="any", + interface="lan", + ) msg = "invalid_lan is not a valid interface" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_net_interface_invalid2(self): - """ test creation of a new rule with invalid interface """ - obj = dict(name='one_rule', source='NET:', destination='any', interface='lan') + """test creation of a new rule with invalid interface""" + obj = dict(name="one_rule", source="NET:", destination="any", interface="lan") msg = "Cannot parse address NET:" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_ip_interface(self): - """ test creation of a new rule with valid interface """ - obj = dict(name='one_rule', source='IP:vt1', destination='any', interface='lan') + """test creation of a new rule with valid interface""" + obj = dict(name="one_rule", source="IP:vt1", destination="any", interface="lan") command = "create rule 'one_rule' on 'lan', source='IP:vt1', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_ip_interface_with_port(self): - """ test creation of a new rule with valid interface """ - obj = dict(name='one_rule', source='IP:vt1:22', destination='any', interface='lan', protocol='tcp') + """test creation of a new rule with valid interface""" + obj = dict( + name="one_rule", + source="IP:vt1:22", + destination="any", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='IP:vt1:22', destination='any', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_ip_interface_invalid(self): - """ test creation of a new rule with invalid interface """ - obj = dict(name='one_rule', source='IP:invalid_lan', destination='any', interface='lan') + """test creation of a new rule with invalid interface""" + obj = dict( + name="one_rule", source="IP:invalid_lan", destination="any", interface="lan" + ) msg = "invalid_lan is not a valid interface" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_interface(self): - """ test creation of a new rule with valid interface """ - obj = dict(name='one_rule', source='vpn', destination='any', interface='lan') + """test creation of a new rule with valid interface""" + obj = dict(name="one_rule", source="vpn", destination="any", interface="lan") command = "create rule 'one_rule' on 'lan', source='vpn', destination='any'" self.do_module_test(obj, command=command) def test_rule_create_port_number(self): - """ test creation of a new rule with port """ - obj = dict(name='one_rule', source='10.10.1.1', destination='10.10.10.1:80', interface='lan', protocol='tcp') + """test creation of a new rule with port""" + obj = dict( + name="one_rule", + source="10.10.1.1", + destination="10.10.10.1:80", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.1', destination='10.10.10.1:80', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_port_alias(self): - """ test creation of a new rule with port alias """ - obj = dict(name='one_rule', source='10.10.1.1', destination='10.10.10.1:port_http', interface='lan', protocol='tcp') + """test creation of a new rule with port alias""" + obj = dict( + name="one_rule", + source="10.10.1.1", + destination="10.10.10.1:port_http", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.1', destination='10.10.10.1:port_http', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_urltable_port_alias(self): - """ test creation of a new rule with urltable port alias """ - obj = dict(name='one_rule', source='10.10.1.1', destination='10.10.10.1:acme_corp_ports', interface='lan', protocol='tcp') + """test creation of a new rule with urltable port alias""" + obj = dict( + name="one_rule", + source="10.10.1.1", + destination="10.10.10.1:acme_corp_ports", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.1', destination='10.10.10.1:acme_corp_ports', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_port_range(self): - """ test creation of a new rule with range of ports """ - obj = dict(name='one_rule', source='10.10.1.1:30000-40000', destination='10.10.10.1', interface='lan', protocol='tcp') + """test creation of a new rule with range of ports""" + obj = dict( + name="one_rule", + source="10.10.1.1:30000-40000", + destination="10.10.10.1", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.1:30000-40000', destination='10.10.10.1', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_port_alias_range(self): - """ test creation of a new rule with range of alias ports """ - obj = dict(name='one_rule', source='10.10.1.1:port_ssh-port_http', destination='10.10.10.1', interface='lan', protocol='tcp') + """test creation of a new rule with range of alias ports""" + obj = dict( + name="one_rule", + source="10.10.1.1:port_ssh-port_http", + destination="10.10.10.1", + interface="lan", + protocol="tcp", + ) command = "create rule 'one_rule' on 'lan', source='10.10.1.1:port_ssh-port_http', destination='10.10.10.1', protocol='tcp'" self.do_module_test(obj, command=command) def test_rule_create_port_alias_range_invalid_1(self): - """ test creation of a new rule with range of invalid alias ports """ - obj = dict(name='one_rule', source='10.10.1.1:port_ssh-openvpn_port', destination='10.10.10.1', interface='lan') + """test creation of a new rule with range of invalid alias ports""" + obj = dict( + name="one_rule", + source="10.10.1.1:port_ssh-openvpn_port", + destination="10.10.10.1", + interface="lan", + ) msg = "Cannot parse port openvpn_port, not port number or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_port_alias_range_invalid_2(self): - """ test creation of a new rule with range of invalid alias ports """ - obj = dict(name='one_rule', source='10.10.1.1:-openvpn_port', destination='10.10.10.1', interface='lan') + """test creation of a new rule with range of invalid alias ports""" + obj = dict( + name="one_rule", + source="10.10.1.1:-openvpn_port", + destination="10.10.10.1", + interface="lan", + ) msg = "Cannot parse port -openvpn_port" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_port_alias_range_invalid_3(self): - """ test creation of a new rule with range of invalid alias ports """ - obj = dict(name='one_rule', source='10.10.1.1:port_ssh-65537', destination='10.10.10.1', interface='lan') + """test creation of a new rule with range of invalid alias ports""" + obj = dict( + name="one_rule", + source="10.10.1.1:port_ssh-65537", + destination="10.10.10.1", + interface="lan", + ) msg = "Cannot parse port 65537, not port number or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_port_number_invalid(self): - """ test creation of a new rule with invalid port number """ - obj = dict(name='one_rule', source='10.10.1.1:65536', destination='10.10.10.1', interface='lan', protocol='tcp') + """test creation of a new rule with invalid port number""" + obj = dict( + name="one_rule", + source="10.10.1.1:65536", + destination="10.10.10.1", + interface="lan", + protocol="tcp", + ) msg = "Cannot parse port 65536, not port number or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_port_alias_invalid(self): - """ test creation of a new rule with invalid port alias """ - obj = dict(name='one_rule', source='10.10.1.1:openvpn_port', destination='10.10.10.1', interface='lan') + """test creation of a new rule with invalid port alias""" + obj = dict( + name="one_rule", + source="10.10.1.1:openvpn_port", + destination="10.10.10.1", + interface="lan", + ) msg = "Cannot parse port openvpn_port, not port number or alias" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_negate_source(self): - """ test creation of a new rule with a not source """ - obj = dict(name='one_rule', source='!srv_admin', destination='any', interface='lan') - command = "create rule 'one_rule' on 'lan', source='!srv_admin', destination='any'" + """test creation of a new rule with a not source""" + obj = dict( + name="one_rule", source="!srv_admin", destination="any", interface="lan" + ) + command = ( + "create rule 'one_rule' on 'lan', source='!srv_admin', destination='any'" + ) self.do_module_test(obj, command=command) def test_rule_create_negate_destination(self): - """ test creation of a new rule with a not destination """ - obj = dict(name='one_rule', source='any', destination='!srv_admin', interface='lan') - command = "create rule 'one_rule' on 'lan', source='any', destination='!srv_admin'" + """test creation of a new rule with a not destination""" + obj = dict( + name="one_rule", source="any", destination="!srv_admin", interface="lan" + ) + command = ( + "create rule 'one_rule' on 'lan', source='any', destination='!srv_admin'" + ) self.do_module_test(obj, command=command) def test_rule_create_separator_top(self): - """ test creation of a new rule at top """ - obj = dict(name='one_rule', source='any', destination='any', interface='vt1', after='top') + """test creation of a new rule at top""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="vt1", + after="top", + ) command = "create rule 'one_rule' on 'vt1', source='any', destination='any', after='top'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) - self.check_separator_idx(obj['interface'], 'test_sep1', 1) - self.check_separator_idx(obj['interface'], 'test_sep2', 4) + self.check_separator_idx(obj["interface"], "test_sep1", 1) + self.check_separator_idx(obj["interface"], "test_sep2", 4) def test_rule_create_separator_bottom(self): - """ test creation of a new rule at bottom """ - obj = dict(name='one_rule', source='any', destination='any', interface='vt1', before='bottom') + """test creation of a new rule at bottom""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="vt1", + before="bottom", + ) command = "create rule 'one_rule' on 'vt1', source='any', destination='any', before='bottom'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 3) - self.check_separator_idx(obj['interface'], 'test_sep1', 0) - self.check_separator_idx(obj['interface'], 'test_sep2', 3) + self.check_separator_idx(obj["interface"], "test_sep1", 0) + self.check_separator_idx(obj["interface"], "test_sep2", 3) def test_rule_create_separator_before_first(self): - """ test creation of a new rule before first rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='vt1', before='r1') + """test creation of a new rule before first rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="vt1", + before="r1", + ) command = "create rule 'one_rule' on 'vt1', source='any', destination='any', before='r1'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) - self.check_separator_idx(obj['interface'], 'test_sep1', 0) - self.check_separator_idx(obj['interface'], 'test_sep2', 4) + self.check_separator_idx(obj["interface"], "test_sep1", 0) + self.check_separator_idx(obj["interface"], "test_sep2", 4) def test_rule_create_separator_after_third(self): - """ test creation of a new rule after third rule """ - obj = dict(name='one_rule', source='any', destination='any', interface='vt1', after='r3') + """test creation of a new rule after third rule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="vt1", + after="r3", + ) command = "create rule 'one_rule' on 'vt1', source='any', destination='any', after='r3'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 3) - self.check_separator_idx(obj['interface'], 'test_sep1', 0) - self.check_separator_idx(obj['interface'], 'test_sep2', 4) + self.check_separator_idx(obj["interface"], "test_sep1", 0) + self.check_separator_idx(obj["interface"], "test_sep2", 4) def test_rule_create_queue(self): - """ test creation of a new rule with default queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', queue='one_queue') + """test creation of a new rule with default queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + queue="one_queue", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', queue='one_queue'" self.do_module_test(obj, command=command) def test_rule_create_queue_ack(self): - """ test creation of a new rule with default queue and ack queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', queue='one_queue', ackqueue='another_queue') + """test creation of a new rule with default queue and ack queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + queue="one_queue", + ackqueue="another_queue", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', queue='one_queue', ackqueue='another_queue'" self.do_module_test(obj, command=command) def test_rule_create_queue_ack_without_default(self): - """ test creation of a new rule with ack queue and without default queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', ackqueue='another_queue') + """test creation of a new rule with ack queue and without default queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + ackqueue="another_queue", + ) msg = "A default queue must be selected when an acknowledge queue is also selected" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_queue_same(self): - """ test creation of a new rule with same default queue and ack queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', queue='one_queue', ackqueue='one_queue') + """test creation of a new rule with same default queue and ack queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + queue="one_queue", + ackqueue="one_queue", + ) msg = "Acknowledge queue and default queue cannot be the same" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_queue_invalid(self): - """ test creation of a new rule with invalid default queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', queue='acme_queue') + """test creation of a new rule with invalid default queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + queue="acme_queue", + ) msg = "Failed to find enabled queue=acme_queue" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_queue_invalid_ack(self): - """ test creation of a new rule with default queue and invalid ack queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', queue='one_queue', ackqueue='acme_queue') + """test creation of a new rule with default queue and invalid ack queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + queue="one_queue", + ackqueue="acme_queue", + ) msg = "Failed to find enabled ackqueue=acme_queue" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_limiter(self): - """ test creation of a new rule with in_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='one_limiter') + """test creation of a new rule with in_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="one_limiter", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', in_queue='one_limiter'" self.do_module_test(obj, command=command) def test_rule_create_limiter_out(self): - """ test creation of a new rule with in_queue and out_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='one_limiter', out_queue='another_limiter') + """test creation of a new rule with in_queue and out_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="one_limiter", + out_queue="another_limiter", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', in_queue='one_limiter', out_queue='another_limiter'" self.do_module_test(obj, command=command) def test_rule_create_limiter_disabled(self): - """ test creation of a new rule with disabled in_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='disabled_limiter') + """test creation of a new rule with disabled in_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="disabled_limiter", + ) msg = "Failed to find enabled in_queue=disabled_limiter" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_limiter_out_without_in(self): - """ test creation of a new rule with out_queue and without in_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', out_queue='another_limiter') + """test creation of a new rule with out_queue and without in_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + out_queue="another_limiter", + ) msg = "A queue must be selected for the In direction before selecting one for Out too" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_limiter_same(self): - """ test creation of a new rule with same in_queue and out_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='one_limiter', out_queue='one_limiter') + """test creation of a new rule with same in_queue and out_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="one_limiter", + out_queue="one_limiter", + ) msg = "In and Out Queue cannot be the same" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_limiter_invalid(self): - """ test creation of a new rule with invalid in_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='acme_queue') + """test creation of a new rule with invalid in_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="acme_queue", + ) msg = "Failed to find enabled in_queue=acme_queue" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_limiter_invalid_out(self): - """ test creation of a new rule with in_queue and invalid out_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='one_limiter', out_queue='acme_queue') + """test creation of a new rule with in_queue and invalid out_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="one_limiter", + out_queue="acme_queue", + ) msg = "Failed to find enabled out_queue=acme_queue" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_limiter_floating_any(self): - """ test creation of a new rule with in_queue and invalid out_queue """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', in_queue='one_limiter', floating='yes', direction='any') + """test creation of a new rule with in_queue and invalid out_queue""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + in_queue="one_limiter", + floating="yes", + direction="any", + ) msg = "Limiters can not be used in Floating rules without choosing a direction" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_gateway(self): - """ test creation of a new rule with gateway """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', gateway='GW_LAN') + """test creation of a new rule with gateway""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + gateway="GW_LAN", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', gateway='GW_LAN'" self.do_module_test(obj, command=command) def test_rule_create_gateway_invalid(self): - """ test creation of a new rule with invalid gateway """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', gateway='GW_WLAN') + """test creation of a new rule with invalid gateway""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + gateway="GW_WLAN", + ) msg = 'Gateway "GW_WLAN" does not exist or does not match target rule ip protocol.' self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_gateway_invalid_ipprotocol(self): - """ test creation of a new rule with gateway """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', ipprotocol='inet6', gateway='GW_LAN') - msg = 'Gateway "GW_LAN" does not exist or does not match target rule ip protocol.' + """test creation of a new rule with gateway""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + ipprotocol="inet6", + gateway="GW_LAN", + ) + msg = ( + 'Gateway "GW_LAN" does not exist or does not match target rule ip protocol.' + ) self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_gateway_floating(self): - """ test creation of a new floating rule with gateway """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', floating='yes', direction='in', gateway='GW_LAN') + """test creation of a new floating rule with gateway""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + floating="yes", + direction="in", + gateway="GW_LAN", + ) command = "create rule 'one_rule' on 'floating(lan)', source='any', destination='any', direction='in', gateway='GW_LAN'" self.do_module_test(obj, command=command) def test_rule_create_gateway_floating_any(self): - """ test creation of a new floating rule with gateway """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', floating='yes', direction='any', gateway='GW_LAN') + """test creation of a new floating rule with gateway""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + floating="yes", + direction="any", + gateway="GW_LAN", + ) msg = "Gateways can not be used in Floating rules without choosing a direction" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_gateway_group(self): - """ test creation of a new rule with gateway group """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', gateway='GWGroup') + """test creation of a new rule with gateway group""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + gateway="GWGroup", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', gateway='GWGroup'" self.do_module_test(obj, command=command) def test_rule_create_gateway_group_invalid_ipprotocol(self): - """ test creation of a new rule with gateway group """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', ipprotocol='inet6', gateway='GWGroup') + """test creation of a new rule with gateway group""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + ipprotocol="inet6", + gateway="GWGroup", + ) msg = 'Gateway "GWGroup" does not exist or does not match target rule ip protocol.' self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_tracker(self): - """ test creation of a new rule with tracker """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', tracker='1234') + """test creation of a new rule with tracker""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + tracker="1234", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', tracker='1234'" self.do_module_test(obj, command=command) def test_rule_create_tracker_leading0(self): - """ test creation of a new rule with tracker with a leading 0 """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', tracker='0100000101') + """test creation of a new rule with tracker with a leading 0""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + tracker="0100000101", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', tracker='0100000101'" self.do_module_test(obj, command=command) def test_rule_create_tracker_invalid(self): - """ test creation of a new rule with invalid tracker """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', tracker='-1234') - msg = 'tracker -1234 must be a positive integer' + """test creation of a new rule with invalid tracker""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + tracker="-1234", + ) + msg = "tracker -1234 must be a positive integer" self.do_module_test(obj, failed=True, msg=msg) def test_rule_create_schedule(self): - """ test creation of a new rule with schedule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', sched='workdays') + """test creation of a new rule with schedule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + sched="workdays", + ) command = "create rule 'one_rule' on 'lan', source='any', destination='any', sched='workdays'" self.do_module_test(obj, command=command) def test_rule_create_schedule_invalid(self): - """ test creation of a new rule with invalid schedule """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan', sched='acme') - msg = 'Schedule acme does not exist' + """test creation of a new rule with invalid schedule""" + obj = dict( + name="one_rule", + source="any", + destination="any", + interface="lan", + sched="acme", + ) + msg = "Schedule acme does not exist" self.do_module_test(obj, failed=True, msg=msg) diff --git a/tests/unit/plugins/modules/test_pfsense_rule_misc.py b/tests/unit/plugins/modules/test_pfsense_rule_misc.py index 87b0b417..daccbbb0 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule_misc.py +++ b/tests/unit/plugins/modules/test_pfsense_rule_misc.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -10,18 +11,25 @@ if sys.version_info < (2, 7): pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") -from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import set_module_args +from ansible_collections.community.internal_test_tools.tests.unit.plugins.modules.utils import ( + set_module_args, +) from .test_pfsense_rule import TestPFSenseRuleModule class TestPFSenseRuleMiscModule(TestPFSenseRuleModule): - ############## # delete # def test_rule_delete(self): - """ test deleting a rule """ - obj = dict(name='test_rule_3', source='any', destination='any', interface='wan', protocol='tcp') + """test deleting a rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any", + interface="wan", + protocol="tcp", + ) command = "delete rule 'test_rule_3' on 'wan'" self.do_module_test(obj, command=command, delete=True) @@ -29,8 +37,8 @@ def test_rule_delete(self): # misc # def test_check_mode(self): - """ test check mode """ - obj = dict(name='one_rule', source='any', destination='any', interface='lan') + """test check mode""" + obj = dict(name="one_rule", source="any", destination="any", interface="lan") with set_module_args(self.args_from_var(obj, _ansible_check_mode=True)): self.execute_module(changed=True) self.assertFalse(self.load_xml_result()) diff --git a/tests/unit/plugins/modules/test_pfsense_rule_noop.py b/tests/unit/plugins/modules/test_pfsense_rule_noop.py index 526ba99c..c3ea0960 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule_noop.py +++ b/tests/unit/plugins/modules/test_pfsense_rule_noop.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -14,152 +15,346 @@ class TestPFSenseRuleNoopModule(TestPFSenseRuleModule): - ############################ # rule noop tests # def test_rule_noop_action(self): - """ test not updating action of a rule to block """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', action='pass', protocol='tcp') + """test not updating action of a rule to block""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + action="pass", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_disabled(self): - """ test not updating disabled of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', disabled='False', protocol='tcp') + """test not updating disabled of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + disabled="False", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_enabled(self): - """ test not updating disabled of a rule """ - obj = dict(name='test_lan_100_1', source='any', destination='any', interface='lan_100', disabled='True', protocol='tcp') + """test not updating disabled of a rule""" + obj = dict( + name="test_lan_100_1", + source="any", + destination="any", + interface="lan_100", + disabled="True", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_disabled_default(self): - """ test not updating disabled of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', protocol='tcp') + """test not updating disabled of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_floating_interface(self): - """ test not updating interface of a floating rule """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp') + """test not updating interface of a floating rule""" + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_floating_direction(self): - """ test not updating direction of a rule to out """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp') + """test not updating direction of a rule to out""" + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_inet(self): - """ test not updating ippprotocol of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', ipprotocol='inet', protocol='tcp') + """test not updating ippprotocol of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + ipprotocol="inet", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_protocol(self): - """ test not updating protocol of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', protocol='tcp') + """test not updating protocol of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_log_no(self): - """ test not updating log of a rule to no """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', log='no', protocol='tcp') + """test not updating log of a rule to no""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + log="no", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_log_yes(self): - """ test not updating log of a rule to no """ - obj = dict(name='test_rule_2', source='any', destination='any', interface='wan', log='yes', protocol='tcp') + """test not updating log of a rule to no""" + obj = dict( + name="test_rule_2", + source="any", + destination="any", + interface="wan", + log="yes", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_log_default(self): - """ test not updating log of a rule to default """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', log='no', protocol='tcp') + """test not updating log of a rule to default""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + log="no", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_source_and_destination(self): - """ test not updating source and destination of a rule """ - obj = dict(name='ads_to_ads_tcp_2_3', source='ad_poc3:port_ldap_ssl', destination='ad_poc1:port_ldap_ssl', interface='lan', protocol='tcp') + """test not updating source and destination of a rule""" + obj = dict( + name="ads_to_ads_tcp_2_3", + source="ad_poc3:port_ldap_ssl", + destination="ad_poc1:port_ldap_ssl", + interface="lan", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_negate_source(self): - """ test creation of a new rule with a not source """ - obj = dict(name='not_rule_src', source='!srv_admin', destination='any:port_ssh', interface='lan', protocol='tcp') + """test creation of a new rule with a not source""" + obj = dict( + name="not_rule_src", + source="!srv_admin", + destination="any:port_ssh", + interface="lan", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_negate_destination(self): - """ test creation of a new rule with a not destination """ - obj = dict(name='not_rule_dst', source='any', destination='!srv_admin:port_ssh', interface='lan', protocol='tcp') + """test creation of a new rule with a not destination""" + obj = dict( + name="not_rule_dst", + source="any", + destination="!srv_admin:port_ssh", + interface="lan", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_before(self): - """ test not updating position of a rule to before another """ - obj = dict(name='test_rule_2', source='any', destination='any', interface='wan', log='yes', protocol='tcp', before='test_rule_3') + """test not updating position of a rule to before another""" + obj = dict( + name="test_rule_2", + source="any", + destination="any", + interface="wan", + log="yes", + protocol="tcp", + before="test_rule_3", + ) self.do_module_test(obj, changed=False) def test_rule_noop_before_bottom(self): - """ test not updating position of a rule to bottom """ - obj = dict(name='antilock_out_3', source='any', destination='any:443', interface='wan', protocol='tcp', before='bottom') + """test not updating position of a rule to bottom""" + obj = dict( + name="antilock_out_3", + source="any", + destination="any:443", + interface="wan", + protocol="tcp", + before="bottom", + ) self.do_module_test(obj, changed=False) def test_rule_noop_position_bottom(self): - """ test not updating position of a rule to bottom """ - obj = dict(name='antilock_out_3', source='any', destination='any:443', interface='wan', protocol='tcp') + """test not updating position of a rule to bottom""" + obj = dict( + name="antilock_out_3", + source="any", + destination="any:443", + interface="wan", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_position_middle(self): - """ test not updating position of a rule to before another """ - obj = dict(name='test_rule_2', source='any', destination='any', interface='wan', log='yes', protocol='tcp') + """test not updating position of a rule to before another""" + obj = dict( + name="test_rule_2", + source="any", + destination="any", + interface="wan", + log="yes", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_after(self): - """ test not updating position of a rule to after another rule """ - obj = dict(name='test_rule_2', source='any', destination='any', interface='wan', log='yes', protocol='tcp', after='test_rule') + """test not updating position of a rule to after another rule""" + obj = dict( + name="test_rule_2", + source="any", + destination="any", + interface="wan", + log="yes", + protocol="tcp", + after="test_rule", + ) self.do_module_test(obj, changed=False) def test_rule_noop_after_top(self): - """ test not updating position of a rule to top """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', log='no', protocol='tcp', after='top') + """test not updating position of a rule to top""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + log="no", + protocol="tcp", + after="top", + ) self.do_module_test(obj, changed=False) def test_rule_noop_separator_top(self): - """ test not updating position of a rule to top """ - obj = dict(name='r1', source='any', destination='any', interface='vt1', protocol='tcp') + """test not updating position of a rule to top""" + obj = dict( + name="r1", source="any", destination="any", interface="vt1", protocol="tcp" + ) self.do_module_test(obj, changed=False) def test_rule_noop_separator_bottom(self): - """ test not updating position of a rule to bottom """ - obj = dict(name='r3', source='any', destination='any', interface='vt1', protocol='tcp') + """test not updating position of a rule to bottom""" + obj = dict( + name="r3", source="any", destination="any", interface="vt1", protocol="tcp" + ) self.do_module_test(obj, changed=False) def test_rule_noop_queue_ack(self): - """ test updating queue of a rule """ - obj = dict(name='test_lan_100_2', source='any', destination='any', interface='lan_100', queue='one_queue', ackqueue='another_queue', protocol='tcp') + """test updating queue of a rule""" + obj = dict( + name="test_lan_100_2", + source="any", + destination="any", + interface="lan_100", + queue="one_queue", + ackqueue="another_queue", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_queue(self): - """ test updating queue and ackqueue of a rule """ - obj = dict(name='test_lan_100_3', source='any', destination='any', interface='lan_100', queue='one_queue', protocol='tcp') + """test updating queue and ackqueue of a rule""" + obj = dict( + name="test_lan_100_3", + source="any", + destination="any", + interface="lan_100", + queue="one_queue", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_limiter_out(self): - """ test updating queue of a rule """ + """test updating queue of a rule""" obj = dict( - name='test_lan_100_4', source='any', destination='any', interface='lan_100', in_queue='one_limiter', out_queue='another_limiter', protocol='tcp') + name="test_lan_100_4", + source="any", + destination="any", + interface="lan_100", + in_queue="one_limiter", + out_queue="another_limiter", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_limiter_in(self): - """ test updating queue and ackqueue of a rule """ - obj = dict(name='test_lan_100_5', source='any', destination='any', interface='lan_100', in_queue='one_limiter', protocol='tcp') + """test updating queue and ackqueue of a rule""" + obj = dict( + name="test_lan_100_5", + source="any", + destination="any", + interface="lan_100", + in_queue="one_limiter", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_tracker(self): - """ test updating tracker of a rule """ - obj = dict(name='test_lan_100_5', source='any', destination='any', interface='lan_100', in_queue='one_limiter', protocol='tcp', tracker=1545574416) + """test updating tracker of a rule""" + obj = dict( + name="test_lan_100_5", + source="any", + destination="any", + interface="lan_100", + in_queue="one_limiter", + protocol="tcp", + tracker=1545574416, + ) self.do_module_test(obj, changed=False) def test_rule_noop_tracker(self): - """ test updating tracker of a rule """ - obj = dict(name='test_lan_100_5', source='any', destination='any', interface='lan_100', in_queue='one_limiter', protocol='tcp') + """test updating tracker of a rule""" + obj = dict( + name="test_lan_100_5", + source="any", + destination="any", + interface="lan_100", + in_queue="one_limiter", + protocol="tcp", + ) self.do_module_test(obj, changed=False) def test_rule_noop_schedule(self): - """ test updating scheduling of a rule """ - obj = dict(name='test_rule_sched', source='any', destination='any', interface='lan_100', action='pass', protocol='tcp', sched='workdays') + """test updating scheduling of a rule""" + obj = dict( + name="test_rule_sched", + source="any", + destination="any", + interface="lan_100", + action="pass", + protocol="tcp", + sched="workdays", + ) self.do_module_test(obj, changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_rule_separator.py b/tests/unit/plugins/modules/test_pfsense_rule_separator.py index 08e22c9a..ffa893ab 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule_separator.py +++ b/tests/unit/plugins/modules/test_pfsense_rule_separator.py @@ -1,13 +1,16 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import sys import pytest from ansible_collections.pfsensible.core.plugins.modules import pfsense_rule_separator -from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import PFSenseRuleSeparatorModule +from ansible_collections.pfsensible.core.plugins.module_utils.rule_separator import ( + PFSenseRuleSeparatorModule, +) from .pfsense_module import TestPFSenseModule if sys.version_info < (2, 7): @@ -15,128 +18,138 @@ class TestPFSenseRuleSeparatorModule(TestPFSenseModule): - module = pfsense_rule_separator def __init__(self, *args, **kwargs): super(TestPFSenseRuleSeparatorModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_rule_separator_config.xml' + self.config_file = "pfsense_rule_separator_config.xml" self.pfmodule = PFSenseRuleSeparatorModule def get_target_elt(self, obj, absent=False, module_result=None): - """ get separator from XML """ - if obj.get('floating'): - interface = 'floatingrules' + """get separator from XML""" + if obj.get("floating"): + interface = "floatingrules" else: - interface = self.unalias_interface(obj['interface']) + interface = self.unalias_interface(obj["interface"]) - filter_elt = self.assert_find_xml_elt(self.xml_result, 'filter') - separator_elt = self.assert_find_xml_elt(filter_elt, 'separator') + filter_elt = self.assert_find_xml_elt(self.xml_result, "filter") + separator_elt = self.assert_find_xml_elt(filter_elt, "separator") iface_elt = self.assert_find_xml_elt(separator_elt, interface) for separator_elt in iface_elt: - text_elt = separator_elt.find('text') - if text_elt is not None and text_elt.text == obj['name']: + text_elt = separator_elt.find("text") + if text_elt is not None and text_elt.text == obj["name"]: if absent: - self.fail('Separator ' + obj['name'] + ' found on interface ' + interface) + self.fail( + "Separator " + obj["name"] + " found on interface " + interface + ) return separator_elt if not absent: - self.fail('Separator ' + obj['name'] + ' not found on interface ' + interface) + self.fail( + "Separator " + obj["name"] + " not found on interface " + interface + ) return None def check_target_elt(self, obj, target_elt): - """ check XML separator definition """ - if obj.get('floating'): - interface = 'floatingrules' + """check XML separator definition""" + if obj.get("floating"): + interface = "floatingrules" else: - interface = self.unalias_interface(obj['interface']) + interface = self.unalias_interface(obj["interface"]) - self.assert_xml_elt_equal(target_elt, 'if', interface) + self.assert_xml_elt_equal(target_elt, "if", interface) - if 'color' not in obj: - self.assert_xml_elt_equal(target_elt, 'color', 'bg-info') + if "color" not in obj: + self.assert_xml_elt_equal(target_elt, "color", "bg-info") else: - self.assert_xml_elt_equal(target_elt, 'color', 'bg-' + obj['color']) + self.assert_xml_elt_equal(target_elt, "color", "bg-" + obj["color"]) def check_separator_idx(self, separator, expected_idx): - """ test the logical position of separator """ + """test the logical position of separator""" separator_elt = self.get_target_elt(separator) - row_elt = self.assert_find_xml_elt(separator_elt, 'row') - idx = int(row_elt.text.replace('fr', '')) + row_elt = self.assert_find_xml_elt(separator_elt, "row") + idx = int(row_elt.text.replace("fr", "")) if idx != expected_idx: - self.fail('Idx of separator ' + separator['name'] + ' if wrong: ' + str(idx) + ', expected: ' + str(expected_idx)) + self.fail( + "Idx of separator " + + separator["name"] + + " if wrong: " + + str(idx) + + ", expected: " + + str(expected_idx) + ) ############## # hosts # def test_separator_create(self): - """ test creation of a new separator """ - separator = dict(name='voip', interface='lan_100') + """test creation of a new separator""" + separator = dict(name="voip", interface="lan_100") command = "create rule_separator 'voip' on 'lan_100', color='info'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 6) def test_separator_create_floating(self): - """ test creation of a new separator """ - separator = dict(name='voip', floating=True) + """test creation of a new separator""" + separator = dict(name="voip", floating=True) command = "create rule_separator 'voip' on 'floating', color='info'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 0) def test_separator_create_top(self): - """ test creation of a new separator at top """ - separator = dict(name='voip', interface='lan_100', after='top') + """test creation of a new separator at top""" + separator = dict(name="voip", interface="lan_100", after="top") command = "create rule_separator 'voip' on 'lan_100', color='info', after='top'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 0) def test_separator_create_bottom(self): - """ test creation of a new separator at bottom """ - separator = dict(name='voip', interface='lan', before='bottom') + """test creation of a new separator at bottom""" + separator = dict(name="voip", interface="lan", before="bottom") command = "create rule_separator 'voip' on 'lan', color='info', before='bottom'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 14) def test_separator_create_after(self): - """ test creation of a new separator at bottom """ - separator = dict(name='voip', interface='lan', after='antilock_out_1') + """test creation of a new separator at bottom""" + separator = dict(name="voip", interface="lan", after="antilock_out_1") command = "create rule_separator 'voip' on 'lan', color='info', after='antilock_out_1'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 1) def test_separator_create_before(self): - """ test creation of a new separator at bottom """ - separator = dict(name='voip', interface='lan', before='antilock_out_2') + """test creation of a new separator at bottom""" + separator = dict(name="voip", interface="lan", before="antilock_out_2") command = "create rule_separator 'voip' on 'lan', color='info', before='antilock_out_2'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 1) def test_separator_delete(self): - """ test deletion of a separator """ - separator = dict(name='test_separator', interface='lan') + """test deletion of a separator""" + separator = dict(name="test_separator", interface="lan") command = "delete rule_separator 'test_separator' on 'lan'" self.do_module_test(separator, command=command, delete=True) def test_separator_delete_inexistent(self): - """ test deletion of an inexistent separator """ - separator = dict(name='test_separator', interface='wan') - self.do_module_test(separator, command='', changed=False, delete=True) + """test deletion of an inexistent separator""" + separator = dict(name="test_separator", interface="wan") + self.do_module_test(separator, command="", changed=False, delete=True) def test_separator_update_noop(self): - """ test changing nothing to a separator """ - separator = dict(name='test_separator', interface='lan', color='info') + """test changing nothing to a separator""" + separator = dict(name="test_separator", interface="lan", color="info") self.do_module_test(separator, changed=False) def test_separator_update_color(self): - """ test updating color of a separator """ - separator = dict(name='test_separator', interface='lan', color='warning') + """test updating color of a separator""" + separator = dict(name="test_separator", interface="lan", color="warning") command = "update rule_separator 'test_separator' on 'lan' set color='warning'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 1) def test_separator_update_position(self): - """ test updating position of a separator """ - separator = dict(name='test_separator', interface='lan', after='top') + """test updating position of a separator""" + separator = dict(name="test_separator", interface="lan", after="top") command = "update rule_separator 'test_separator' on 'lan' set color='info', after='top'" self.do_module_test(separator, command=command) self.check_separator_idx(separator, 0) diff --git a/tests/unit/plugins/modules/test_pfsense_rule_update.py b/tests/unit/plugins/modules/test_pfsense_rule_update.py index a90574c9..812c8955 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule_update.py +++ b/tests/unit/plugins/modules/test_pfsense_rule_update.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -14,346 +15,729 @@ class TestPFSenseRuleUpdateModule(TestPFSenseRuleModule): - ############################ # rule update tests # def test_rule_update_action(self): - """ test updating action of a rule to block """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', action='block', protocol='tcp') + """test updating action of a rule to block""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + action="block", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set action='block'" self.do_module_test(obj, command=command) def test_rule_update_disabled(self): - """ test updating disabled of a rule to True """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', disabled='True', protocol='tcp') + """test updating disabled of a rule to True""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + disabled="True", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set disabled=True" self.do_module_test(obj, command=command) def test_rule_update_enabled(self): - """ test updating disabled of a rule to False """ - obj = dict(name='test_lan_100_1', source='any', destination='any', interface='lan_100', disabled='False', protocol='tcp') + """test updating disabled of a rule to False""" + obj = dict( + name="test_lan_100_1", + source="any", + destination="any", + interface="lan_100", + disabled="False", + protocol="tcp", + ) command = "update rule 'test_lan_100_1' on 'lan_100' set disabled=False" self.do_module_test(obj, command=command) def test_rule_update_enabled_default(self): - """ test updating disabled of a rule to default """ - obj = dict(name='test_lan_100_1', source='any', destination='any', interface='lan_100', protocol='tcp') + """test updating disabled of a rule to default""" + obj = dict( + name="test_lan_100_1", + source="any", + destination="any", + interface="lan_100", + protocol="tcp", + ) command = "update rule 'test_lan_100_1' on 'lan_100' set disabled=False" self.do_module_test(obj, command=command) def test_rule_update_floating_interface(self): - """ test updating interface of a floating rule """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='lan', floating='yes', direction='any', protocol='tcp') - command = "update rule 'test_rule_floating' on 'floating(wan)' set interface='lan'" + """test updating interface of a floating rule""" + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="lan", + floating="yes", + direction="any", + protocol="tcp", + ) + command = ( + "update rule 'test_rule_floating' on 'floating(wan)' set interface='lan'" + ) self.do_module_test(obj, command=command) def test_rule_update_floating_interfaces(self): - """ test updating interfaces of a floating rule """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='lan,lan_100', floating='yes', direction='any', protocol='tcp') + """test updating interfaces of a floating rule""" + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="lan,lan_100", + floating="yes", + direction="any", + protocol="tcp", + ) command = "update rule 'test_rule_floating' on 'floating(wan)' set interface='lan,lan_100'" self.do_module_test(obj, command=command) def test_rule_update_floating_direction(self): - """ test updating direction of a rule to out """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='yes', direction='out', protocol='tcp') - command = "update rule 'test_rule_floating' on 'floating(wan)' set direction='out'" + """test updating direction of a rule to out""" + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="out", + protocol="tcp", + ) + command = ( + "update rule 'test_rule_floating' on 'floating(wan)' set direction='out'" + ) self.do_module_test(obj, command=command) def test_rule_update_floating_quick(self): - """ test updating quick match of a floating rule """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp', quick='yes') + """test updating quick match of a floating rule""" + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + quick="yes", + ) command = "update rule 'test_rule_floating' on 'floating(wan)' set quick=True" self.do_module_test(obj, command=command) def test_rule_update_floating_remove_quick(self): - """ test updating quick match of a floating rule """ - obj = dict(name='test_rule_floating_quick', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp') - command = "update rule 'test_rule_floating_quick' on 'floating(wan)' set quick=False" + """test updating quick match of a floating rule""" + obj = dict( + name="test_rule_floating_quick", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + ) + command = ( + "update rule 'test_rule_floating_quick' on 'floating(wan)' set quick=False" + ) self.do_module_test(obj, command=command) def test_rule_update_floating_yes(self): - """ test updating floating of a rule to yes - Since you can't change the floating mode of a rule, it should create a new rule + """test updating floating of a rule to yes + Since you can't change the floating mode of a rule, it should create a new rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp') + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + ) command = "create rule 'test_rule' on 'floating(wan)', source='any', destination='any', protocol='tcp', direction='any'" self.do_module_test(obj, command=command) - other_rule = dict(name='test_rule', source='any', destination='any', interface='wan', floating='no', protocol='tcp') + other_rule = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + floating="no", + protocol="tcp", + ) other_rule_elt = self.get_target_elt(other_rule) self.check_target_elt(other_rule, other_rule_elt) def test_rule_update_floating_no(self): - """ test updating floating of a rule to no - Since you can't change the floating mode of a rule, it should create a new rule + """test updating floating of a rule to no + Since you can't change the floating mode of a rule, it should create a new rule """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='no', direction='any', protocol='tcp') + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="no", + direction="any", + protocol="tcp", + ) command = "create rule 'test_rule_floating' on 'wan', source='any', destination='any', protocol='tcp', direction='any'" self.do_module_test(obj, command=command) - other_rule = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp') + other_rule = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + ) other_rule_elt = self.get_target_elt(other_rule) self.check_target_elt(other_rule, other_rule_elt) def test_rule_update_floating_default(self): - """ test updating floating of a rule to default (no) - Since you can't change the floating mode of a rule, it should create a new rule + """test updating floating of a rule to default (no) + Since you can't change the floating mode of a rule, it should create a new rule """ - obj = dict(name='test_rule_floating', source='any', destination='any', interface='wan', protocol='tcp') + obj = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + protocol="tcp", + ) command = "create rule 'test_rule_floating' on 'wan', source='any', destination='any', protocol='tcp'" self.do_module_test(obj, command=command) - other_rule = dict(name='test_rule_floating', source='any', destination='any', interface='wan', floating='yes', direction='any', protocol='tcp') + other_rule = dict( + name="test_rule_floating", + source="any", + destination="any", + interface="wan", + floating="yes", + direction="any", + protocol="tcp", + ) other_rule_elt = self.get_target_elt(other_rule) self.check_target_elt(other_rule, other_rule_elt) def test_rule_update_inet(self): - """ test updating ippprotocol of a rule to ipv4 and ipv6 """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', ipprotocol='inet46', protocol='tcp') + """test updating ippprotocol of a rule to ipv4 and ipv6""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + ipprotocol="inet46", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set ipprotocol='inet46'" self.do_module_test(obj, command=command) def test_rule_update_protocol_udp(self): - """ test updating protocol of a rule to udp """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', protocol='udp') + """test updating protocol of a rule to udp""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + protocol="udp", + ) command = "update rule 'test_rule' on 'wan' set protocol='udp'" self.do_module_test(obj, command=command) def test_rule_update_protocol_any(self): - """ test updating protocol of a rule to udp """ - obj = dict(name='r2', source='any', destination='any', interface='vt1', protocol='any') + """test updating protocol of a rule to udp""" + obj = dict( + name="r2", source="any", destination="any", interface="vt1", protocol="any" + ) command = "update rule 'r2' on 'vt1' set protocol='any'" self.do_module_test(obj, command=command) def test_rule_update_protocol_tcp_udp(self): - """ test updating protocol of a rule to tcp/udp """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', protocol='tcp/udp') + """test updating protocol of a rule to tcp/udp""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + protocol="tcp/udp", + ) command = "update rule 'test_rule' on 'wan' set protocol='tcp/udp'" self.do_module_test(obj, command=command) def test_rule_update_log_yes(self): - """ test updating log of a rule to yes """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', log='yes', protocol='tcp') + """test updating log of a rule to yes""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + log="yes", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set log=True" self.do_module_test(obj, command=command) def test_rule_update_log_no(self): - """ test updating log of a rule to no """ - obj = dict(name='test_rule_2', source='any', destination='any', interface='wan', log='no', protocol='tcp') + """test updating log of a rule to no""" + obj = dict( + name="test_rule_2", + source="any", + destination="any", + interface="wan", + log="no", + protocol="tcp", + ) command = "update rule 'test_rule_2' on 'wan' set log=False" self.do_module_test(obj, command=command) def test_rule_update_tcpflags_any_yes(self): - """ test updating log of a rule to yes """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', protocol='tcp', tcpflags_any='yes') + """test updating log of a rule to yes""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + protocol="tcp", + tcpflags_any="yes", + ) command = "update rule 'test_rule' on 'wan' set tcpflags_any=True" self.do_module_test(obj, command=command) def test_rule_update_tcpflags_any_no(self): - """ test updating log of a rule to no """ - obj = dict(name='test_rule_4', source='any', destination='any', interface='lan_100', tcpflags_any='no') + """test updating log of a rule to no""" + obj = dict( + name="test_rule_4", + source="any", + destination="any", + interface="lan_100", + tcpflags_any="no", + ) command = "update rule 'test_rule_4' on 'lan_100' set tcpflags_any=False" self.do_module_test(obj, command=command) def test_rule_update_log_default(self): - """ test updating log of a rule to default """ - obj = dict(name='test_rule_2', source='any', destination='any', interface='wan', protocol='tcp') + """test updating log of a rule to default""" + obj = dict( + name="test_rule_2", + source="any", + destination="any", + interface="wan", + protocol="tcp", + ) command = "update rule 'test_rule_2' on 'wan' set log=False" self.do_module_test(obj, command=command) def test_rule_update_negate_add_source(self): - """ test updating source of a rule with a not """ - obj = dict(name='test_rule_2', source='!srv_admin', destination='any', interface='wan', protocol='tcp', log=True) + """test updating source of a rule with a not""" + obj = dict( + name="test_rule_2", + source="!srv_admin", + destination="any", + interface="wan", + protocol="tcp", + log=True, + ) command = "update rule 'test_rule_2' on 'wan' set source='!srv_admin'" self.do_module_test(obj, command=command) def test_rule_update_negate_add_destination(self): - """ test updating destination of a rule with a not """ - obj = dict(name='test_rule_2', source='any', destination='!srv_admin', interface='wan', protocol='tcp', log=True) + """test updating destination of a rule with a not""" + obj = dict( + name="test_rule_2", + source="any", + destination="!srv_admin", + interface="wan", + protocol="tcp", + log=True, + ) command = "update rule 'test_rule_2' on 'wan' set destination='!srv_admin'" self.do_module_test(obj, command=command) def test_rule_update_negate_remove_source(self): - """ test updating source of a rule remove the not """ - obj = dict(name='not_rule_src', source='srv_admin', destination='any:port_ssh', interface='lan', protocol='tcp') + """test updating source of a rule remove the not""" + obj = dict( + name="not_rule_src", + source="srv_admin", + destination="any:port_ssh", + interface="lan", + protocol="tcp", + ) command = "update rule 'not_rule_src' on 'lan' set source='srv_admin'" self.do_module_test(obj, command=command) def test_rule_update_negate_remove_destination(self): - """ test updating destination of a rule remove the not """ - obj = dict(name='not_rule_dst', source='any', destination='srv_admin:port_ssh', interface='lan', protocol='tcp') + """test updating destination of a rule remove the not""" + obj = dict( + name="not_rule_dst", + source="any", + destination="srv_admin:port_ssh", + interface="lan", + protocol="tcp", + ) command = "update rule 'not_rule_dst' on 'lan' set destination='srv_admin'" self.do_module_test(obj, command=command) def test_rule_update_before(self): - """ test updating position of a rule to before another """ - obj = dict(name='test_rule_3', source='any', destination='any:port_http', interface='wan', protocol='tcp', before='test_rule') + """test updating position of a rule to before another""" + obj = dict( + name="test_rule_3", + source="any", + destination="any:port_http", + interface="wan", + protocol="tcp", + before="test_rule", + ) command = "update rule 'test_rule_3' on 'wan' set before='test_rule'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) def test_rule_update_before_bottom(self): - """ test updating position of a rule to bottom """ - obj = dict(name='test_rule_3', source='any', destination='any:port_http', interface='wan', protocol='tcp', before='bottom') + """test updating position of a rule to bottom""" + obj = dict( + name="test_rule_3", + source="any", + destination="any:port_http", + interface="wan", + protocol="tcp", + before="bottom", + ) command = "update rule 'test_rule_3' on 'wan' set before='bottom'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 3) def test_rule_update_after(self): - """ test updating position of a rule to after another rule """ - obj = dict(name='test_rule_3', source='any', destination='any:port_http', interface='wan', protocol='tcp', after='antilock_out_3') + """test updating position of a rule to after another rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any:port_http", + interface="wan", + protocol="tcp", + after="antilock_out_3", + ) command = "update rule 'test_rule_3' on 'wan' set after='antilock_out_3'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 3) def test_rule_update_after_self(self): - """ test updating position of a rule to after same rule """ - obj = dict(name='test_rule_3', source='any', destination='any', interface='wan', protocol='tcp', after='test_rule_3') - msg = 'Cannot specify the current rule in after' + """test updating position of a rule to after same rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any", + interface="wan", + protocol="tcp", + after="test_rule_3", + ) + msg = "Cannot specify the current rule in after" self.do_module_test(obj, failed=True, msg=msg) def test_rule_update_before_self(self): - """ test updating position of a rule to before same rule """ - obj = dict(name='test_rule_3', source='any', destination='any', interface='wan', protocol='tcp', before='test_rule_3') - msg = 'Cannot specify the current rule in before' + """test updating position of a rule to before same rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any", + interface="wan", + protocol="tcp", + before="test_rule_3", + ) + msg = "Cannot specify the current rule in before" self.do_module_test(obj, failed=True, msg=msg) def test_rule_update_after_top(self): - """ test updating position of a rule to top """ - obj = dict(name='test_rule_3', source='any', destination='any:port_http', interface='wan', protocol='tcp', after='top') + """test updating position of a rule to top""" + obj = dict( + name="test_rule_3", + source="any", + destination="any:port_http", + interface="wan", + protocol="tcp", + after="top", + ) command = "update rule 'test_rule_3' on 'wan' set after='top'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) def test_rule_update_separator_top(self): - """ test updating position of a rule to top """ - obj = dict(name='r2', source='any', destination='any', interface='vt1', protocol='tcp', after='top') + """test updating position of a rule to top""" + obj = dict( + name="r2", + source="any", + destination="any", + interface="vt1", + protocol="tcp", + after="top", + ) command = "update rule 'r2' on 'vt1' set after='top'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) - self.check_separator_idx(obj['interface'], 'test_sep1', 1) - self.check_separator_idx(obj['interface'], 'test_sep2', 3) + self.check_separator_idx(obj["interface"], "test_sep1", 1) + self.check_separator_idx(obj["interface"], "test_sep2", 3) def test_rule_update_separator_bottom(self): - """ test updating position of a rule to bottom """ - obj = dict(name='r1', source='any', destination='any', interface='vt1', protocol='tcp', before='bottom') + """test updating position of a rule to bottom""" + obj = dict( + name="r1", + source="any", + destination="any", + interface="vt1", + protocol="tcp", + before="bottom", + ) command = "update rule 'r1' on 'vt1' set before='bottom'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 2) - self.check_separator_idx(obj['interface'], 'test_sep1', 0) - self.check_separator_idx(obj['interface'], 'test_sep2', 2) + self.check_separator_idx(obj["interface"], "test_sep1", 0) + self.check_separator_idx(obj["interface"], "test_sep2", 2) def test_rule_update_separator_before_first(self): - """ test creation of a new rule at bottom """ - obj = dict(name='r3', source='any', destination='any', interface='vt1', protocol='tcp', before='r1') + """test creation of a new rule at bottom""" + obj = dict( + name="r3", + source="any", + destination="any", + interface="vt1", + protocol="tcp", + before="r1", + ) command = "update rule 'r3' on 'vt1' set before='r1'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 0) - self.check_separator_idx(obj['interface'], 'test_sep1', 0) - self.check_separator_idx(obj['interface'], 'test_sep2', 3) + self.check_separator_idx(obj["interface"], "test_sep1", 0) + self.check_separator_idx(obj["interface"], "test_sep2", 3) def test_rule_update_separator_after_third(self): - """ test creation of a new rule at bottom """ - obj = dict(name='r1', source='any', destination='any', interface='vt1', protocol='tcp', after='r3') + """test creation of a new rule at bottom""" + obj = dict( + name="r1", + source="any", + destination="any", + interface="vt1", + protocol="tcp", + after="r3", + ) command = "update rule 'r1' on 'vt1' set after='r3'" self.do_module_test(obj, command=command) self.check_rule_idx(obj, 2) - self.check_separator_idx(obj['interface'], 'test_sep1', 0) - self.check_separator_idx(obj['interface'], 'test_sep2', 3) + self.check_separator_idx(obj["interface"], "test_sep1", 0) + self.check_separator_idx(obj["interface"], "test_sep2", 3) def test_rule_update_queue_set(self): - """ test updating queue of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', queue='one_queue', protocol='tcp') + """test updating queue of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + queue="one_queue", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set queue='one_queue'" self.do_module_test(obj, command=command) def test_rule_update_queue_set_ack(self): - """ test updating queue and ackqueue of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', queue='one_queue', ackqueue='another_queue', protocol='tcp') + """test updating queue and ackqueue of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + queue="one_queue", + ackqueue="another_queue", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set queue='one_queue', ackqueue='another_queue'" self.do_module_test(obj, command=command) def test_rule_update_queue_unset_ack(self): - """ test updating ackqueue of a rule """ - obj = dict(name='test_lan_100_2', source='any', destination='any', interface='lan_100', queue='one_queue', protocol='tcp') + """test updating ackqueue of a rule""" + obj = dict( + name="test_lan_100_2", + source="any", + destination="any", + interface="lan_100", + queue="one_queue", + protocol="tcp", + ) command = "update rule 'test_lan_100_2' on 'lan_100' set ackqueue=none" self.do_module_test(obj, command=command) def test_rule_update_queue_unset(self): - """ test updating queue of a rule """ - obj = dict(name='test_lan_100_3', source='any', destination='any', interface='lan_100', protocol='tcp') + """test updating queue of a rule""" + obj = dict( + name="test_lan_100_3", + source="any", + destination="any", + interface="lan_100", + protocol="tcp", + ) command = "update rule 'test_lan_100_3' on 'lan_100' set queue=none" self.do_module_test(obj, command=command) def test_rule_update_limiter_set(self): - """ test updating limiter of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', in_queue='one_limiter', protocol='tcp') + """test updating limiter of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + in_queue="one_limiter", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set in_queue='one_limiter'" self.do_module_test(obj, command=command) def test_rule_update_limiter_set_out(self): - """ test updating limiter in and out of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', in_queue='one_limiter', out_queue='another_limiter', protocol='tcp') + """test updating limiter in and out of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + in_queue="one_limiter", + out_queue="another_limiter", + protocol="tcp", + ) command = "update rule 'test_rule' on 'wan' set in_queue='one_limiter', out_queue='another_limiter'" self.do_module_test(obj, command=command) def test_rule_update_limiter_unset_out(self): - """ test updating limiter out of a rule """ - obj = dict(name='test_lan_100_4', source='any', destination='any', interface='lan_100', in_queue='one_limiter', protocol='tcp') + """test updating limiter out of a rule""" + obj = dict( + name="test_lan_100_4", + source="any", + destination="any", + interface="lan_100", + in_queue="one_limiter", + protocol="tcp", + ) command = "update rule 'test_lan_100_4' on 'lan_100' set out_queue=none" self.do_module_test(obj, command=command) def test_rule_update_limiter_unset(self): - """ test updating limiter of a rule """ - obj = dict(name='test_lan_100_5', source='any', destination='any', interface='lan_100', protocol='tcp') + """test updating limiter of a rule""" + obj = dict( + name="test_lan_100_5", + source="any", + destination="any", + interface="lan_100", + protocol="tcp", + ) command = "update rule 'test_lan_100_5' on 'lan_100' set in_queue=none" self.do_module_test(obj, command=command) def test_rule_update_gateway_set(self): - """ test updating gateway of a rule """ - obj = dict(name='test_rule_3', source='any', destination='any:port_http', interface='wan', protocol='tcp', gateway='GW_WAN') + """test updating gateway of a rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any:port_http", + interface="wan", + protocol="tcp", + gateway="GW_WAN", + ) command = "update rule 'test_rule_3' on 'wan' set gateway='GW_WAN'" self.do_module_test(obj, command=command) def test_rule_update_gateway_unset(self): - """ test updating gateway of a rule """ - obj = dict(name='antilock_out_1', source='any', destination='any:port_ssh', interface='lan', protocol='tcp', log=True) + """test updating gateway of a rule""" + obj = dict( + name="antilock_out_1", + source="any", + destination="any:port_ssh", + interface="lan", + protocol="tcp", + log=True, + ) command = "update rule 'antilock_out_1' on 'lan' set gateway=none" self.do_module_test(obj, command=command) def test_rule_update_tracker(self): - """ test updating tracker of a rule """ - obj = dict(name='test_lan_100_5', source='any', destination='any', interface='lan_100', in_queue='one_limiter', protocol='tcp', tracker='1234') + """test updating tracker of a rule""" + obj = dict( + name="test_lan_100_5", + source="any", + destination="any", + interface="lan_100", + in_queue="one_limiter", + protocol="tcp", + tracker="1234", + ) command = "update rule 'test_lan_100_5' on 'lan_100' set tracker='1234'" self.do_module_test(obj, command=command) def test_rule_update_icmp(self): - """ test updating ipprotocol to icmptype """ - obj = dict(name='r1', source='any', destination='any', interface='vt1', protocol='icmp', icmptype='echorep,echoreq') - command = "update rule 'r1' on 'vt1' set protocol='icmp', icmptype='echorep,echoreq'" + """test updating ipprotocol to icmptype""" + obj = dict( + name="r1", + source="any", + destination="any", + interface="vt1", + protocol="icmp", + icmptype="echorep,echoreq", + ) + command = ( + "update rule 'r1' on 'vt1' set protocol='icmp', icmptype='echorep,echoreq'" + ) self.do_module_test(obj, command=command) def test_rule_update_port_old_syntax(self): - """ test updating gateway of a rule """ - obj = dict(name='test_rule_3', source='any', destination='any:port_ssh', interface='wan', protocol='tcp') + """test updating gateway of a rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any:port_ssh", + interface="wan", + protocol="tcp", + ) command = "update rule 'test_rule_3' on 'wan' set destination_port='port_ssh'" self.do_module_test(obj, command=command) def test_rule_update_port_new_syntax(self): - """ test updating gateway of a rule """ - obj = dict(name='test_rule_3', source='any', destination='any', destination_port='port_ssh', interface='wan', protocol='tcp') + """test updating gateway of a rule""" + obj = dict( + name="test_rule_3", + source="any", + destination="any", + destination_port="port_ssh", + interface="wan", + protocol="tcp", + ) command = "update rule 'test_rule_3' on 'wan' set destination_port='port_ssh'" self.do_module_test(obj, command=command) def test_rule_update_schedule(self): - """ test updating scheduling of a rule """ - obj = dict(name='test_rule', source='any', destination='any', interface='wan', action='pass', protocol='tcp', sched='workdays') + """test updating scheduling of a rule""" + obj = dict( + name="test_rule", + source="any", + destination="any", + interface="wan", + action="pass", + protocol="tcp", + sched="workdays", + ) command = "update rule 'test_rule' on 'wan' set sched='workdays'" self.do_module_test(obj, command=command) def test_rule_update_remove_schedule(self): - """ test updating scheduling of a rule """ - obj = dict(name='test_rule_sched', source='any', destination='any', interface='lan_100', action='pass', protocol='tcp') + """test updating scheduling of a rule""" + obj = dict( + name="test_rule_sched", + source="any", + destination="any", + interface="lan_100", + action="pass", + protocol="tcp", + ) command = "update rule 'test_rule_sched' on 'lan_100' set sched=none" self.do_module_test(obj, command=command) diff --git a/tests/unit/plugins/modules/test_pfsense_setup.py b/tests/unit/plugins/modules/test_pfsense_setup.py index 263e9efa..8bbc1a69 100644 --- a/tests/unit/plugins/modules/test_pfsense_setup.py +++ b/tests/unit/plugins/modules/test_pfsense_setup.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -12,42 +13,76 @@ from ansible_collections.pfsensible.core.plugins.modules import pfsense_setup from .pfsense_module import TestPFSenseModule -from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import ( + patch, +) class TestPFSenseSetupModule(TestPFSenseModule): - module = pfsense_setup def __init__(self, *args, **kwargs): super(TestPFSenseSetupModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_setup_config.xml' + self.config_file = "pfsense_setup_config.xml" @staticmethod def get_args_fields(): - """ return params fields """ - fields = ['hostname', 'domain', 'dns_addresses', 'dns_hostnames', 'dns_gateways', 'dnsallowoverride', 'dnslocalhost', 'timezone'] - fields += ['timeservers', 'language', 'webguicss', 'webguifixedmenu', 'webguihostnamemenu', 'dashboardcolumns', 'interfacessort'] - fields += ['dashboardavailablewidgetspanel', 'systemlogsfilterpanel', 'systemlogsmanagelogpanel', 'statusmonitoringsettingspanel'] - fields += ['requirestatefilter', 'webguileftcolumnhyper', 'disablealiaspopupdetail', 'roworderdragging', 'logincss', 'loginshowhost'] + """return params fields""" + fields = [ + "hostname", + "domain", + "dns_addresses", + "dns_hostnames", + "dns_gateways", + "dnsallowoverride", + "dnslocalhost", + "timezone", + ] + fields += [ + "timeservers", + "language", + "webguicss", + "webguifixedmenu", + "webguihostnamemenu", + "dashboardcolumns", + "interfacessort", + ] + fields += [ + "dashboardavailablewidgetspanel", + "systemlogsfilterpanel", + "systemlogsmanagelogpanel", + "statusmonitoringsettingspanel", + ] + fields += [ + "requirestatefilter", + "webguileftcolumnhyper", + "disablealiaspopupdetail", + "roworderdragging", + "logincss", + "loginshowhost", + ] return fields def setUp(self): - """ mocking up """ + """mocking up""" super(TestPFSenseSetupModule, self).setUp() # Remove validate command for webguicss which references files on the pfSense instance - self.mock_validate_webguicss = patch.dict('ansible_collections.pfsensible.core.plugins.modules.pfsense_setup.SETUP_ARG_ROUTE', - dict(webguicss=dict(parse=pfsense_setup.p2o_webguicss))) + self.mock_validate_webguicss = patch.dict( + "ansible_collections.pfsensible.core.plugins.modules.pfsense_setup.SETUP_ARG_ROUTE", + dict(webguicss=dict(parse=pfsense_setup.p2o_webguicss)), + ) self.mock_validate_webguicss.start() - self.mock_run_command = patch('ansible.module_utils.basic.AnsibleModule.run_command') + self.mock_run_command = patch( + "ansible.module_utils.basic.AnsibleModule.run_command" + ) self.run_command = self.mock_run_command.start() - self.run_command.return_value = (0, '', '') + self.run_command.return_value = (0, "", "") def tearDown(self): - """ mocking down """ + """mocking down""" super(TestPFSenseSetupModule, self).tearDown() self.mock_validate_webguicss.stop() @@ -57,12 +92,12 @@ def tearDown(self): # tests utils # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated xml definition """ - return self.assert_find_xml_elt(self.xml_result, 'system') + """get the generated xml definition""" + return self.assert_find_xml_elt(self.xml_result, "system") def check_target_elt(self, obj, target_elt): - """ test the xml definition of setup elt """ - webgui_elt = self.assert_find_xml_elt(target_elt, 'webgui') + """test the xml definition of setup elt""" + webgui_elt = self.assert_find_xml_elt(target_elt, "webgui") def check_param(param, elt): if obj.get(param) is not None: @@ -75,31 +110,33 @@ def check_bool_param(param, elt): else: self.assert_not_find_xml_elt(elt, param) - check_param('hostname', target_elt) - check_param('domain', target_elt) - check_bool_param('dnsallowoverride', target_elt) - check_param('dnslocalhost', target_elt) - check_param('timezone', target_elt) - check_param('timeservers', target_elt) - check_param('language', target_elt) - - if obj.get('webguicss') is not None: - self.assert_xml_elt_equal(webgui_elt, 'webguicss', obj['webguicss'] + '.css') - - self.check_param_bool(obj, webgui_elt, 'webguifixedmenu', value_true='fixed') - check_param('webguihostnamemenu', webgui_elt) - check_param('dashboardcolumns', webgui_elt) - check_bool_param('interfacessort', webgui_elt) - check_bool_param('dashboardavailablewidgetspanel', webgui_elt) - check_bool_param('systemlogsfilterpanel', webgui_elt) - check_bool_param('systemlogsmanagelogpanel', webgui_elt) - check_bool_param('statusmonitoringsettingspanel', webgui_elt) - check_bool_param('requirestatefilter', webgui_elt) - check_bool_param('webguileftcolumnhyper', webgui_elt) - check_bool_param('disablealiaspopupdetail', webgui_elt) - check_bool_param('roworderdragging', webgui_elt) - check_bool_param('loginshowhost', webgui_elt) - check_param('logincss', webgui_elt) + check_param("hostname", target_elt) + check_param("domain", target_elt) + check_bool_param("dnsallowoverride", target_elt) + check_param("dnslocalhost", target_elt) + check_param("timezone", target_elt) + check_param("timeservers", target_elt) + check_param("language", target_elt) + + if obj.get("webguicss") is not None: + self.assert_xml_elt_equal( + webgui_elt, "webguicss", obj["webguicss"] + ".css" + ) + + self.check_param_bool(obj, webgui_elt, "webguifixedmenu", value_true="fixed") + check_param("webguihostnamemenu", webgui_elt) + check_param("dashboardcolumns", webgui_elt) + check_bool_param("interfacessort", webgui_elt) + check_bool_param("dashboardavailablewidgetspanel", webgui_elt) + check_bool_param("systemlogsfilterpanel", webgui_elt) + check_bool_param("systemlogsmanagelogpanel", webgui_elt) + check_bool_param("statusmonitoringsettingspanel", webgui_elt) + check_bool_param("requirestatefilter", webgui_elt) + check_bool_param("webguileftcolumnhyper", webgui_elt) + check_bool_param("disablealiaspopupdetail", webgui_elt) + check_bool_param("roworderdragging", webgui_elt) + check_bool_param("loginshowhost", webgui_elt) + check_param("logincss", webgui_elt) # TODO: check dns_addresses, dns_hostnames, dns_gateways @@ -107,211 +144,245 @@ def check_bool_param(param, elt): # tests # def test_setup_hostname(self): - """ test setup hostname """ - setup = dict(hostname='acme') + """test setup hostname""" + setup = dict(hostname="acme") command = "update setup general set hostname='acme'" self.do_module_test(setup, command=command, state=None) def test_setup_hostname_invalid(self): - """ test setup hostname """ - setup = dict(hostname='acme.corp.com') - msg = "A valid hostname is specified, but the domain name part should be omitted" + """test setup hostname""" + setup = dict(hostname="acme.corp.com") + msg = ( + "A valid hostname is specified, but the domain name part should be omitted" + ) self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_hostname_invalid2(self): - """ test setup hostname """ - setup = dict(hostname='(invalid)') + """test setup hostname""" + setup = dict(hostname="(invalid)") msg = "The hostname can only contain the characters A-Z, 0-9 and '-'. It may not start or end with '-'" self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_domain(self): - """ test setup domain """ - setup = dict(domain='corp.com') + """test setup domain""" + setup = dict(domain="corp.com") command = "update setup general set domain='corp.com'" self.do_module_test(setup, command=command, state=None) def test_setup_domain_invalid(self): - """ test setup domain """ - setup = dict(domain='@invalid.com') + """test setup domain""" + setup = dict(domain="@invalid.com") msg = "The domain may only contain the characters a-z, 0-9, '-' and '.'" self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dnsallowoverride(self): - """ test setup general """ + """test setup general""" setup = dict(dnsallowoverride=False) command = "update setup general set dnsallowoverride=False" self.do_module_test(setup, command=command, state=None) def test_setup_dnslocalhost(self): - """ test setup dnslocalhost """ - setup = dict(dnslocalhost='remote') + """test setup dnslocalhost""" + setup = dict(dnslocalhost="remote") command = "update setup general set dnslocalhost='remote'" self.do_module_test(setup, command=command, state=None) def test_setup_webguifixedmenu(self): - """ test setup webguifixedmenu """ + """test setup webguifixedmenu""" setup = dict(webguifixedmenu=True) command = "update setup general set webguifixedmenu=True" self.do_module_test(setup, command=command, state=None) def test_setup_interfacessort(self): - """ test setup interfacessort """ + """test setup interfacessort""" setup = dict(interfacessort=True) command = "update setup general set interfacessort=True" self.do_module_test(setup, command=command, state=None) def test_setup_dashboardavailablewidgetspanel(self): - """ test setup dashboardavailablewidgetspanel """ + """test setup dashboardavailablewidgetspanel""" setup = dict(dashboardavailablewidgetspanel=True) command = "update setup general set dashboardavailablewidgetspanel=True" self.do_module_test(setup, command=command, state=None) def test_setup_systemlogsfilterpanel(self): - """ test setup systemlogsfilterpanel """ + """test setup systemlogsfilterpanel""" setup = dict(systemlogsfilterpanel=True) command = "update setup general set systemlogsfilterpanel=True" self.do_module_test(setup, command=command, state=None) def test_setup_systemlogsmanagelogpanel(self): - """ test setup systemlogsmanagelogpanel """ + """test setup systemlogsmanagelogpanel""" setup = dict(systemlogsmanagelogpanel=True) command = "update setup general set systemlogsmanagelogpanel=True" self.do_module_test(setup, command=command, state=None) def test_setup_statusmonitoringsettingspanel(self): - """ test setup statusmonitoringsettingspanel """ + """test setup statusmonitoringsettingspanel""" setup = dict(statusmonitoringsettingspanel=True) command = "update setup general set statusmonitoringsettingspanel=True" self.do_module_test(setup, command=command, state=None) def test_setup_requirestatefilter(self): - """ test setup requirestatefilter """ + """test setup requirestatefilter""" setup = dict(requirestatefilter=True) command = "update setup general set requirestatefilter=True" self.do_module_test(setup, command=command, state=None) def test_setup_webguileftcolumnhyper(self): - """ test setup webguileftcolumnhyper """ + """test setup webguileftcolumnhyper""" setup = dict(webguileftcolumnhyper=True) command = "update setup general set webguileftcolumnhyper=True" self.do_module_test(setup, command=command, state=None) def test_setup_disablealiaspopupdetail(self): - """ test setup disablealiaspopupdetail """ + """test setup disablealiaspopupdetail""" setup = dict(disablealiaspopupdetail=True) command = "update setup general set disablealiaspopupdetail=True" self.do_module_test(setup, command=command, state=None) def test_setup_roworderdragging(self): - """ test setup roworderdragging """ + """test setup roworderdragging""" setup = dict(roworderdragging=True) command = "update setup general set roworderdragging=True" self.do_module_test(setup, command=command, state=None) def test_setup_loginshowhost(self): - """ test setup loginshowhost """ + """test setup loginshowhost""" setup = dict(loginshowhost=True) command = "update setup general set loginshowhost=True" self.do_module_test(setup, command=command, state=None) def test_setup_language(self): - """ test setup language """ - setup = dict(language='fr_FR') + """test setup language""" + setup = dict(language="fr_FR") command = "update setup general set language='fr_FR'" self.do_module_test(setup, command=command, state=None) def test_setup_timeservers(self): - """ test setup timeservers """ - setup = dict(timeservers='1.2.3.4 0.pool.ntp.org') + """test setup timeservers""" + setup = dict(timeservers="1.2.3.4 0.pool.ntp.org") command = "update setup general set timeservers='1.2.3.4 0.pool.ntp.org'" self.do_module_test(setup, command=command, state=None) def test_setup_timezone(self): - """ test setup timezone """ - setup = dict(timezone='Europe/Paris') + """test setup timezone""" + setup = dict(timezone="Europe/Paris") command = "update setup general set timezone='Europe/Paris'" self.do_module_test(setup, command=command, state=None) def test_setup_webguicss(self): - """ test setup webguicss """ - setup = dict(webguicss='pfSense-dark') + """test setup webguicss""" + setup = dict(webguicss="pfSense-dark") command = "update setup general set webguicss='pfSense-dark'" self.do_module_test(setup, command=command, state=None) def test_setup_webguihostnamemenu(self): - """ test setup webguihostnamemenu """ - setup = dict(webguihostnamemenu='fqdn') + """test setup webguihostnamemenu""" + setup = dict(webguihostnamemenu="fqdn") command = "update setup general set webguihostnamemenu='fqdn'" self.do_module_test(setup, command=command, state=None) def test_setup_dashboardcolumns(self): - """ test setup dashboardcolumns """ - setup = dict(dashboardcolumns='3') + """test setup dashboardcolumns""" + setup = dict(dashboardcolumns="3") command = "update setup general set dashboardcolumns='3'" self.do_module_test(setup, command=command, state=None) def test_setup_dashboardcolumns_invalid(self): - """ test setup dashboardcolumns """ - setup = dict(dashboardcolumns='0') + """test setup dashboardcolumns""" + setup = dict(dashboardcolumns="0") msg = "The submitted Dashboard Columns value is invalid." self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_logincss(self): - """ test setup logincss """ - setup = dict(logincss='ff0000') + """test setup logincss""" + setup = dict(logincss="ff0000") command = "update setup general set logincss='ff0000'" self.do_module_test(setup, command=command, state=None) def test_setup_logincss_invalid(self): - """ test setup logincss """ - setup = dict(logincss='gg0000') + """test setup logincss""" + setup = dict(logincss="gg0000") msg = "logincss must be a six digits hexadecimal string." self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dns_addresses(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.4.4 8.8.8.8', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN') + """test setup dns""" + setup = dict( + dns_addresses="8.8.4.4 8.8.8.8", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_WAN", + ) command = "update setup general set dns_addresses='8.8.4.4 8.8.8.8', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN'" self.do_module_test(setup, command=command, state=None) def test_setup_dns_addresses_invalid(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.4.4 8.8.8.8 256.255.254.253', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN') - msg = 'A valid IP address must be specified for DNS server 256.255.254.253.' + """test setup dns""" + setup = dict( + dns_addresses="8.8.4.4 8.8.8.8 256.255.254.253", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_WAN", + ) + msg = "A valid IP address must be specified for DNS server 256.255.254.253." self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dns_addresses_ipv6(self): - """ test setup dns """ - setup = dict(dns_addresses='2001::8 8.8.4.4', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN') + """test setup dns""" + setup = dict( + dns_addresses="2001::8 8.8.4.4", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_WAN", + ) command = "update setup general set dns_addresses='2001::8 8.8.4.4', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN'" self.do_module_test(setup, command=command, state=None) def test_setup_dns_addresses_invalid_ipv4(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.4.4 8.8.8.8', dns_hostnames='acme1 acme2', dns_gateways='none GW_LAN6') + """test setup dns""" + setup = dict( + dns_addresses="8.8.4.4 8.8.8.8", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_LAN6", + ) msg = 'The IPv6 gateway "GW_LAN6" can not be specified for IPv4 DNS server "8.8.8.8".' self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dns_addresses_invalid_ipv6(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.4.4 2001::8', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN') + """test setup dns""" + setup = dict( + dns_addresses="8.8.4.4 2001::8", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_WAN", + ) msg = 'The IPv4 gateway "GW_WAN" can not be specified for IPv6 DNS server "2001::8".' self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dns_addresses_invalid_gw(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.4.4 8.8.8.8', dns_hostnames='acme1 acme2', dns_gateways='none GW_ACME') + """test setup dns""" + setup = dict( + dns_addresses="8.8.4.4 8.8.8.8", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_ACME", + ) msg = 'The gateway "GW_ACME" does not exist.' self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dns_addresses_invalid_gw2(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.4.4 192.168.1.1', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN') + """test setup dns""" + setup = dict( + dns_addresses="8.8.4.4 192.168.1.1", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_WAN", + ) msg = "A gateway can not be assigned to DNS '192.168.1.1' server which is on a directly connected network." self.do_module_test(setup, msg=msg, state=None, failed=True) def test_setup_dns_addresses_duplicates(self): - """ test setup dns """ - setup = dict(dns_addresses='8.8.8.8 8.8.8.8', dns_hostnames='acme1 acme2', dns_gateways='none GW_WAN') + """test setup dns""" + setup = dict( + dns_addresses="8.8.8.8 8.8.8.8", + dns_hostnames="acme1 acme2", + dns_gateways="none GW_WAN", + ) msg = "Each configured DNS server must have a unique IP address. Remove the duplicated IP." self.do_module_test(setup, msg=msg, state=None, failed=True) diff --git a/tests/unit/plugins/modules/test_pfsense_user.py b/tests/unit/plugins/modules/test_pfsense_user.py index 21699f70..58c75fa0 100644 --- a/tests/unit/plugins/modules/test_pfsense_user.py +++ b/tests/unit/plugins/modules/test_pfsense_user.py @@ -1,7 +1,8 @@ # Copyright: (c) 2020, Orion Poplawski # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -15,76 +16,91 @@ class TestPFSenseUserModule(TestPFSenseModule): - module = pfsense_user def __init__(self, *args, **kwargs): super(TestPFSenseUserModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_user_config.xml' + self.config_file = "pfsense_user_config.xml" self.pfmodule = pfsense_user.PFSenseUserModule @staticmethod def runTest(): - """ dummy function needed to instantiate this test module from another in python 2.7 """ + """dummy function needed to instantiate this test module from another in python 2.7""" pass def get_target_elt(self, obj, absent=False, module_result=None): - """ return target elt from XML """ - root_elt = self.assert_find_xml_elt(self.xml_result, 'system') - result = root_elt.findall("user[name='{0}']".format(obj['name'])) + """return target elt from XML""" + root_elt = self.assert_find_xml_elt(self.xml_result, "system") + result = root_elt.findall("user[name='{0}']".format(obj["name"])) if len(result) == 1: return result[0] elif len(result) > 1: - self.fail('Found multiple users for name {0}.'.format(obj['name'])) + self.fail("Found multiple users for name {0}.".format(obj["name"])) else: return None def check_target_elt(self, obj, target_elt): - """ check XML definition of target elt """ + """check XML definition of target elt""" - self.check_param_equal(obj, target_elt, 'name') - self.check_param_equal(obj, target_elt, 'descr') - self.check_param_equal(obj, target_elt, 'scope', default='user') - self.check_param_equal(obj, target_elt, 'uid', default='2001') + self.check_param_equal(obj, target_elt, "name") + self.check_param_equal(obj, target_elt, "descr") + self.check_param_equal(obj, target_elt, "scope", default="user") + self.check_param_equal(obj, target_elt, "uid", default="2001") # TODO - need to load groups # self.check_param_equal(obj, target_elt, 'groups') - self.check_param_equal(obj, target_elt, 'password', xml_field='bcrypt-hash') - self.check_list_param_equal_or_not_find(obj, target_elt, 'priv') - self.check_param_equal_or_not_find(obj, target_elt, 'authorizedkeys') + self.check_param_equal(obj, target_elt, "password", xml_field="bcrypt-hash") + self.check_list_param_equal_or_not_find(obj, target_elt, "priv") + self.check_param_equal_or_not_find(obj, target_elt, "authorizedkeys") ############## # tests # def test_user_create(self): - """ test creation of a new user """ - obj = dict(name='user1', descr='User One', password='$2b$12$D2jkq4Iut3ODUBN0BCrDk.bV3J5N.MrY5YEnGvTXwxeNBkyxjbbtW') - self.do_module_test(obj, command="create user 'user1', descr='User One', uid='2001'") + """test creation of a new user""" + obj = dict( + name="user1", + descr="User One", + password="$2b$12$D2jkq4Iut3ODUBN0BCrDk.bV3J5N.MrY5YEnGvTXwxeNBkyxjbbtW", + ) + self.do_module_test( + obj, command="create user 'user1', descr='User One', uid='2001'" + ) def test_user_delete(self): - """ test deletion of a user """ - obj = dict(name='testdel') + """test deletion of a user""" + obj = dict(name="testdel") self.do_module_test(obj, command="delete user 'testdel'", delete=True) def test_user_update_noop(self): - """ test not updating a user """ - obj = dict(name='testdel', descr='Delete Me', uid='2000') + """test not updating a user""" + obj = dict(name="testdel", descr="Delete Me", uid="2000") self.do_module_test(obj, command="delete user 'testdel'", changed=False) def test_user_update_descr(self): - """ test updating descr of a user """ - obj = dict(name='testdel', descr='Keep Me', uid='2000', password='$2b$12$D2jkq4Iut3ODUBN0BCrDk.bV3J5N.MrY5YEnGvTXwxeNBkyxjbbtW', - priv=['page-dashboard-all']) + """test updating descr of a user""" + obj = dict( + name="testdel", + descr="Keep Me", + uid="2000", + password="$2b$12$D2jkq4Iut3ODUBN0BCrDk.bV3J5N.MrY5YEnGvTXwxeNBkyxjbbtW", + priv=["page-dashboard-all"], + ) self.do_module_test(obj, command="update user 'testdel' set descr='Keep Me'") ############## # misc # def test_create_user_invalid_password(self): - """ test creation of a new user with invalid password """ - obj = dict(name='user1', descr='User One', password='password') - self.do_module_test(obj, command="update user 'testdel'", failed=True, msg='Password (password) does not appear to be a bcrypt hash') + """test creation of a new user with invalid password""" + obj = dict(name="user1", descr="User One", password="password") + self.do_module_test( + obj, + command="update user 'testdel'", + failed=True, + msg="Password (password) does not appear to be a bcrypt hash", + ) def test_delete_inexistent_user(self): - """ test deletion of an inexistent user """ - obj = dict(name='nouser') - self.do_module_test(obj, state='absent', changed=False) + """test deletion of an inexistent user""" + obj = dict(name="nouser") + self.do_module_test(obj, state="absent", changed=False) diff --git a/tests/unit/plugins/modules/test_pfsense_vlan.py b/tests/unit/plugins/modules/test_pfsense_vlan.py index efa3b58b..08e0beb6 100644 --- a/tests/unit/plugins/modules/test_pfsense_vlan.py +++ b/tests/unit/plugins/modules/test_pfsense_vlan.py @@ -1,7 +1,8 @@ # Copyright: (c) 2018, Frederic Bor # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) -from __future__ import (absolute_import, division, print_function) +from __future__ import absolute_import, division, print_function + __metaclass__ = type import pytest @@ -11,17 +12,18 @@ pytestmark = pytest.mark.skip("pfSense Ansible modules require Python >= 2.7") from ansible_collections.pfsensible.core.plugins.modules import pfsense_vlan -from ansible_collections.pfsensible.core.plugins.module_utils.vlan import PFSenseVlanModule +from ansible_collections.pfsensible.core.plugins.module_utils.vlan import ( + PFSenseVlanModule, +) from .pfsense_module import TestPFSenseModule class TestPFSenseVlanModule(TestPFSenseModule): - module = pfsense_vlan def __init__(self, *args, **kwargs): super(TestPFSenseVlanModule, self).__init__(*args, **kwargs) - self.config_file = 'pfsense_vlan_config.xml' + self.config_file = "pfsense_vlan_config.xml" self.pfmodule = PFSenseVlanModule ############## @@ -29,111 +31,122 @@ def __init__(self, *args, **kwargs): # def get_target_elt(self, obj, absent=False, module_result=None): - """ get the generated vlan xml definition """ + """get the generated vlan xml definition""" elt_filter = {} - elt_filter['if'] = self.unalias_interface(obj['interface'], physical=True) - elt_filter['tag'] = str(obj['vlan_id']) + elt_filter["if"] = self.unalias_interface(obj["interface"], physical=True) + elt_filter["tag"] = str(obj["vlan_id"]) - return self.assert_has_xml_tag('vlans', elt_filter, absent=absent) + return self.assert_has_xml_tag("vlans", elt_filter, absent=absent) def check_target_elt(self, obj, target_elt): - """ test the xml definition of vlan """ + """test the xml definition of vlan""" # checking vlanif - self.assert_xml_elt_equal(target_elt, 'vlanif', '{0}.{1}'.format(self.unalias_interface(obj['interface'], physical=True), obj['vlan_id'])) + self.assert_xml_elt_equal( + target_elt, + "vlanif", + "{0}.{1}".format( + self.unalias_interface(obj["interface"], physical=True), obj["vlan_id"] + ), + ) # checking descr - if 'descr' in obj: - self.assert_xml_elt_equal(target_elt, 'descr', obj['descr']) + if "descr" in obj: + self.assert_xml_elt_equal(target_elt, "descr", obj["descr"]) else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'descr') + self.assert_xml_elt_is_none_or_empty(target_elt, "descr") # checking priority - if 'priority' in obj and obj['priority'] is not None: - self.assert_xml_elt_equal(target_elt, 'pcp', str(obj['priority'])) + if "priority" in obj and obj["priority"] is not None: + self.assert_xml_elt_equal(target_elt, "pcp", str(obj["priority"])) else: - self.assert_xml_elt_is_none_or_empty(target_elt, 'pcp') + self.assert_xml_elt_is_none_or_empty(target_elt, "pcp") ############## # tests # def test_vlan_create(self): - """ test creation of a new vlan """ - vlan = dict(vlan_id=100, interface='vmx0') + """test creation of a new vlan""" + vlan = dict(vlan_id=100, interface="vmx0") command = "create vlan 'vmx0.100', descr='', priority=''" self.do_module_test(vlan, command=command) def test_vlan_create_with_assigned_name(self): - """ test creation of a new vlan using assigned name """ - vlan = dict(vlan_id=100, interface='vpn') + """test creation of a new vlan using assigned name""" + vlan = dict(vlan_id=100, interface="vpn") command = "create vlan 'vmx2.100', descr='', priority=''" self.do_module_test(vlan, command=command) def test_vlan_create_with_friendly_name(self): - """ test creation of a new vlan using friendly name """ - vlan = dict(vlan_id=100, interface='opt2') + """test creation of a new vlan using friendly name""" + vlan = dict(vlan_id=100, interface="opt2") command = "create vlan 'vmx3.100', descr='', priority=''" self.do_module_test(vlan, command=command) def test_vlan_create_with_wrong_inteface(self): - """ test creation of a new vlan using wrong interface """ - vlan = dict(vlan_id=100, interface='opt3') + """test creation of a new vlan using wrong interface""" + vlan = dict(vlan_id=100, interface="opt3") msg = "Vlans can't be set on interface opt3" self.do_module_test(vlan, failed=True, msg=msg) def test_vlan_create_with_wrong_vlan(self): - """ test creation of a new vlan using wrong vlan_id """ - vlan = dict(vlan_id=0, interface='opt2') + """test creation of a new vlan using wrong vlan_id""" + vlan = dict(vlan_id=0, interface="opt2") msg = "vlan_id must be between 1 and 4094 on interface opt2" self.do_module_test(vlan, failed=True, msg=msg) def test_vlan_create_with_wrong_prioriy(self): - """ test creation of a new vlan using wrong priority """ - vlan = dict(vlan_id=100, interface='opt2', priority=8) + """test creation of a new vlan using wrong priority""" + vlan = dict(vlan_id=100, interface="opt2", priority=8) msg = "priority must be between 0 and 7 on interface opt2" self.do_module_test(vlan, failed=True, msg=msg) def test_vlan_create_with_priority(self): - """ test creation of a new vlan """ - vlan = dict(vlan_id=100, interface='vmx0', descr='voice') + """test creation of a new vlan""" + vlan = dict(vlan_id=100, interface="vmx0", descr="voice") command = "create vlan 'vmx0.100', descr='voice', priority=''" self.do_module_test(vlan, command=command) def test_vlan_create_with_descr(self): - """ test creation of a new vlan """ - vlan = dict(vlan_id=100, interface='vmx0', priority=5) + """test creation of a new vlan""" + vlan = dict(vlan_id=100, interface="vmx0", priority=5) command = "create vlan 'vmx0.100', descr='', priority='5'" self.do_module_test(vlan, command=command) def test_vlan_delete(self): - """ test deletion of a vlan """ - vlan = dict(vlan_id=100, interface='vmx1') + """test deletion of a vlan""" + vlan = dict(vlan_id=100, interface="vmx1") command = "delete vlan 'vmx1.100'" self.do_module_test(vlan, delete=True, command=command) def test_vlan_delete_used(self): - """ test deletion of a still used vlan """ - vlan = dict(vlan_id=1100, interface='vmx1') - self.do_module_test(vlan, delete=True, failed=True, msg='vlan 1100 on vmx1 cannot be deleted because it is still being used as an interface') + """test deletion of a still used vlan""" + vlan = dict(vlan_id=1100, interface="vmx1") + self.do_module_test( + vlan, + delete=True, + failed=True, + msg="vlan 1100 on vmx1 cannot be deleted because it is still being used as an interface", + ) def test_vlan_delete_unexistent(self): - """ test deletion of a vlan """ - vlan = dict(vlan_id=1200, interface='vmx1') + """test deletion of a vlan""" + vlan = dict(vlan_id=1200, interface="vmx1") self.do_module_test(vlan, delete=True, changed=False) def test_vlan_update_noop(self): - """ test not updating a vlan """ - vlan = dict(vlan_id=1100, interface='vmx1') + """test not updating a vlan""" + vlan = dict(vlan_id=1100, interface="vmx1") self.do_module_test(vlan, changed=False) def test_vlan_update_priority(self): - """ test updating priority """ - vlan = dict(vlan_id=1100, interface='vmx1', priority=1) + """test updating priority""" + vlan = dict(vlan_id=1100, interface="vmx1", priority=1) command = "update vlan 'vmx1.1100' set priority='1'" self.do_module_test(vlan, changed=True, command=command) def test_vlan_update_descr(self): - """ test updating descr """ - vlan = dict(vlan_id=1100, interface='vmx1', descr='test') + """test updating descr""" + vlan = dict(vlan_id=1100, interface="vmx1", descr="test") command = "update vlan 'vmx1.1100' set descr='test'" self.do_module_test(vlan, changed=True, command=command)