forked from hhugo/ppx_deriving
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathppx_deriving_eq.ml
More file actions
186 lines (171 loc) · 7.53 KB
/
ppx_deriving_eq.ml
File metadata and controls
186 lines (171 loc) · 7.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
open Longident
open Location
open Asttypes
open Parsetree
open Ast_helper
open Ast_convenience
module StringSet = Ppx_deriving.StringSet
(* Name of this deriver: registered with [Ppx_deriving.create] below and
   interpolated as the leading "%s" of every error message in this file. *)
let deriver = "eq"
(* Local alias for [Ppx_deriving.raise_errorf], used to report unsupported
   options and types at a given [~loc]. *)
let raise_errorf = Ppx_deriving.raise_errorf
(** Options of the [eq] deriver, as produced by [parse_options]. *)
type eq_options = {
  allow_std_type_shadowing : bool;
  (** Forwarded as [~allow_shadowing] to
      [Ppx_deriving.extract_typename_of_type_group] by [type_decl_str]. *)
}

(** Options in effect when none are given: shadowing is not allowed. *)
let default_eq_options = { allow_std_type_shadowing = false }
(** [parse_options options] folds the [(name, expr)] option pairs given to the
    deriver into an [eq_options] record, starting from [default_eq_options].

    The only supported option is [allow_std_type_shadowing]. Its value is
    honoured: the expression [false] leaves it disabled, and any other
    expression (such as [true]) enables it.

    Fails via [raise_errorf], at the option's expression, for any other
    option name. *)
let parse_options options =
  let option_parser _acc (name, expr) =
    match name with
    | "allow_std_type_shadowing" ->
      (* The value used to be ignored, so an explicit [false] silently turned
         the option on. *)
      let enabled =
        match expr with
        | [%expr false] -> false
        | _ -> true
      in
      { allow_std_type_shadowing = enabled }
    | _ ->
      raise_errorf ~loc:expr.pexp_loc "%s does not support option %s" deriver name in
  List.fold_left option_parser default_eq_options options
(** [attr_equal attrs] returns the expression payload of the [equal]
    attribute found in [attrs] by [Ppx_deriving.attr ~deriver "equal"], or
    [None] when there is no such attribute. *)
let attr_equal attrs =
  let open Ppx_deriving in
  Arg.get_attr ~deriver Arg.expr (attr ~deriver "equal" attrs)
(** [argn side i] is the name of the [i]-th generated argument variable on
    [side]: ["lhs<i>"] for [`lhs] and ["rhs<i>"] for [`rhs]. *)
let argn kind idx =
  match kind with
  | `lhs -> Printf.sprintf "lhs%d" idx
  | `rhs -> Printf.sprintf "rhs%d" idx
(** [pattn side typs] builds one variable pattern per element of [typs],
    named by [argn side] with the element's index ([lhs0], [lhs1], ... or
    [rhs0], [rhs1], ...). The elements themselves are not inspected. *)
let pattn side typs =
  typs |> List.mapi (fun idx _typ -> pvar (argn side idx))
(** [core_type_of_decl ~options ~path type_decl] is the type of the equality
    function generated for [type_decl]. At its core this is [t -> t -> bool],
    where [t] is built by [Ppx_deriving.core_type_of_type_decl].
    [Ppx_deriving.poly_arrow_of_type_decl] extends it using
    [fun var -> var -> var -> bool], presumably adding one comparator argument
    per type parameter; check that function to confirm.

    [options] are only validated here; [path] is unused. *)
let core_type_of_decl ~options ~path type_decl =
  let _ : eq_options = parse_options options in
  let self_typ = Ppx_deriving.core_type_of_type_decl type_decl in
  let comparator_of var = [%type: [%t var] -> [%t var] -> bool] in
  Ppx_deriving.poly_arrow_of_type_decl comparator_of type_decl
    [%type: [%t self_typ] -> [%t self_typ] -> bool]
(** [sig_of_type ~options ~path type_decl] is the one-item signature that
    declares the equality function for [type_decl]. The value's name comes
    from [Ppx_deriving.mangle_type_decl (`Prefix "equal")] and its type from
    [core_type_of_decl]. *)
let sig_of_type ~options ~path type_decl =
  let _ : eq_options = parse_options options in
  let fn_name = mknoloc (Ppx_deriving.mangle_type_decl (`Prefix "equal") type_decl) in
  let fn_type = core_type_of_decl ~options ~path type_decl in
  [Sig.value (Val.mk fn_name fn_type)]
(** [exprsn group_def typs] returns one expression per element [typ] of
    [typs] at index [i]: [typ]'s comparator (from [expr_of_typ]) applied to
    the variables named [argn `lhs i] and [argn `rhs i], which are the names
    bound by [pattn].

    [expr_of_typ group_def typ] returns an expression comparing two values of
    [typ]. Resolution order:
    - an [equal] attribute on [typ] (see [attr_equal]) is used verbatim;
    - a constructor whose unqualified name is in [group_def] (the types of
      the current declaration group) maps to its mangled [equal] function.
      This is checked before the builtin cases, so a group that shadows e.g.
      [int] uses its own comparator;
    - [_] and [unit] always compare equal; scalar builtins use [=];
    - [ref], [list], [array] and [option] get inline comparators;
    - any other constructor is mapped through
      [Ppx_deriving.mangle_lid (`Prefix "equal")] and applied to the
      comparators of its type arguments;
    - tuples and polymorphic variants are compared component-wise;
    - a type variable ['a] maps to the variable [poly_a].
    Unsupported types fail via [raise_errorf]. *)
let rec exprsn group_def typs =
  typs |> List.mapi (fun i typ ->
    app (expr_of_typ group_def typ) [evar (argn `lhs i); evar (argn `rhs i)])
(* See the comment above [exprsn]. *)
and expr_of_typ group_def typ =
  match attr_equal typ.ptyp_attributes with
  | Some fn -> fn
  | None ->
    match typ with
    | { ptyp_desc = Ptyp_constr ({ txt = (Lident id as lid) }, args) }
      when StringSet.mem id group_def ->
      (* Must precede the builtin cases so group-local shadowing wins. *)
      let equal_fn = Exp.ident (mknoloc (Ppx_deriving.mangle_lid (`Prefix "equal") lid)) in
      app equal_fn (List.map (expr_of_typ group_def) args)
    | [%type: _] | [%type: unit] -> [%expr fun _ _ -> true]
    | [%type: int] | [%type: int32] | [%type: Int32.t]
    | [%type: int64] | [%type: Int64.t] | [%type: nativeint] | [%type: Nativeint.t]
    | [%type: float] | [%type: bool] | [%type: char] | [%type: string] |
      [%type: String.t] | [%type: bytes] ->
      [%expr (fun (a:[%t typ]) b -> a = b)]
    | [%type: [%t? typ] ref] -> [%expr fun a b -> [%e expr_of_typ group_def typ] !a !b]
    | [%type: [%t? typ] list] ->
      [%expr
        let rec loop x y =
          match x, y with
          | [], [] -> true
          | a :: x, b :: y -> [%e expr_of_typ group_def typ] a b && loop x y
          | _ -> false
        in (fun x y -> loop x y)]
    | [%type: [%t? typ] array] ->
      (* The recursive call must be guarded by the element test. The previous
         form [(i = len || eq x.(i) y.(i)) && loop (i + 1)] still called
         [loop (len + 1)] once [i] reached [len]. That call read [x.(len + 1)]
         and raised [Invalid_argument] for every pair of equal arrays. *)
      [%expr fun x y ->
        let rec loop i =
          i = Array.length x || ([%e expr_of_typ group_def typ] x.(i) y.(i) && loop (i + 1))
        in Array.length x = Array.length y && loop 0]
    | [%type: [%t? typ] option] ->
      [%expr fun x y ->
        match x, y with
        | None, None -> true
        | Some a, Some b -> [%e expr_of_typ group_def typ] a b
        | _ -> false]
    | { ptyp_desc = Ptyp_constr ({ txt = lid }, args) } ->
      let equal_fn = Exp.ident (mknoloc (Ppx_deriving.mangle_lid (`Prefix "equal") lid)) in
      app equal_fn (List.map (expr_of_typ group_def) args)
    | { ptyp_desc = Ptyp_tuple typs } ->
      [%expr fun [%p ptuple (pattn `lhs typs)] [%p ptuple (pattn `rhs typs)] ->
        [%e exprsn group_def typs |> Ppx_deriving.(fold_exprs (binop_reduce [%expr (&&)]))]]
    | { ptyp_desc = Ptyp_variant (fields, _, _); ptyp_loc } ->
      let cases =
        (fields |> List.map (fun field ->
          let pdup f = ptuple [f "lhs"; f "rhs"] in
          match field with
          | Rtag (label, _, true (*empty*), []) ->
            Exp.case (pdup (fun _ -> Pat.variant label None)) [%expr true]
          | Rtag (label, _, false, [typ]) ->
            Exp.case (pdup (fun var -> Pat.variant label (Some (pvar var))))
                     (app (expr_of_typ group_def typ) [evar "lhs"; evar "rhs"])
          | Rinherit ({ ptyp_desc = Ptyp_constr (tname, _) } as typ) ->
            (* [#tname as lhs], [#tname as rhs]: both sides must belong to the
               inherited type, which is then compared with its own function. *)
            Exp.case (pdup (fun var -> Pat.alias (Pat.type_ tname) (mknoloc var)))
                     (app (expr_of_typ group_def typ) [evar "lhs"; evar "rhs"])
          | _ ->
            raise_errorf ~loc:ptyp_loc "%s cannot be derived for %s"
                         deriver (Ppx_deriving.string_of_core_type typ))) @
        [Exp.case (pvar "_") [%expr false]]
      in
      [%expr fun lhs rhs -> [%e Exp.match_ [%expr lhs, rhs] cases]]
    | { ptyp_desc = Ptyp_var name } -> evar ("poly_"^name)
    | { ptyp_desc = Ptyp_alias (typ, _) } -> expr_of_typ group_def typ
    | { ptyp_loc } ->
      raise_errorf ~loc:ptyp_loc "%s cannot be derived for %s"
                   deriver (Ppx_deriving.string_of_core_type typ)
(** [str_of_type ~options ~path group_def type_decl] returns the value
    binding of the equality function for [type_decl]. The binding is
    annotated with [core_type_of_decl] (made strong by
    [Ppx_deriving.strong_type_of_type]), and its body is wrapped by
    [Ppx_deriving.poly_fun_of_type_decl].

    - An abstract type with a manifest reuses the manifest's comparator.
    - A variant requires both sides to use the same constructor and compares
      its arguments pairwise; constructors without arguments compare [true].
    - A record compares every field. Label attributes are merged into the
      field type's attributes, so an [equal] override can sit on the label.

    Fails via [raise_errorf] for fully abstract and open types. *)
let str_of_type ~options ~path group_def ({ ptype_loc = loc } as type_decl) =
  let _ : eq_options = parse_options options in
  let comparator =
    match type_decl.ptype_kind, type_decl.ptype_manifest with
    | Ptype_abstract, Some manifest -> expr_of_typ group_def manifest
    | Ptype_abstract, None ->
      raise_errorf ~loc "%s cannot be derived for fully abstract types" deriver
    | Ptype_open, _ ->
      raise_errorf ~loc "%s cannot be derived for open types" deriver
    | Ptype_variant constrs, _ ->
      let case_of_constr { pcd_name = { txt = name }; pcd_args = typs } =
        let args_equal =
          Ppx_deriving.(fold_exprs ~unit:[%expr true] (binop_reduce [%expr (&&)]))
            (exprsn group_def typs)
        in
        let both_sides =
          ptuple [pconstr name (pattn `lhs typs); pconstr name (pattn `rhs typs)] in
        Exp.case both_sides args_equal
      in
      (* With zero or one constructor the cases above are already exhaustive,
         so a catch-all would be unused. *)
      let mismatch =
        match constrs with
        | _ :: _ :: _ -> [Exp.case (pvar "_") [%expr false]]
        | [] | [_] -> []
      in
      let cases = List.map case_of_constr constrs @ mismatch in
      [%expr fun lhs rhs -> [%e Exp.match_ [%expr lhs, rhs] cases]]
    | Ptype_record labels, _ ->
      let field_equal { pld_name = { txt = name }; pld_type; pld_attributes } =
        (* Type attributes first, then label attributes, so that
           expr_of_typ sees both. *)
        let field_typ =
          { pld_type with ptyp_attributes = pld_type.ptyp_attributes @ pld_attributes } in
        let proj obj = Exp.field obj (mknoloc (Lident name)) in
        app (expr_of_typ group_def field_typ) [proj (evar "lhs"); proj (evar "rhs")]
      in
      let all_fields = List.map field_equal labels in
      [%expr fun lhs rhs ->
        [%e Ppx_deriving.(fold_exprs (binop_reduce [%expr (&&)])) all_fields]]
  in
  let binder = pvar (Ppx_deriving.mangle_type_decl (`Prefix "equal") type_decl) in
  let annot =
    Ppx_deriving.strong_type_of_type (core_type_of_decl ~options ~path type_decl) in
  let polymorphize = Ppx_deriving.poly_fun_of_type_decl type_decl in
  [Vb.mk (Pat.constraint_ binder annot) (polymorphize comparator)]
(** [type_decl_str ~options ~path type_decls] generates the structure for a
    whole [type ... and ...] group. All equality functions are emitted as a
    single recursive value definition.

    The group's type names come from
    [Ppx_deriving.extract_typename_of_type_group], honouring
    [allow_std_type_shadowing]. They are passed to [str_of_type] so that
    references between types of the group resolve to the generated functions.

    If those names include [bool], fails via [raise_errorf] at the location of
    the first declaration, whatever the options say. *)
let type_decl_str ~options ~path type_decls =
  let { allow_std_type_shadowing } = parse_options options in
  let typename_set =
    Ppx_deriving.extract_typename_of_type_group deriver
      ~allow_shadowing:allow_std_type_shadowing type_decls
  in
  let here_loc = (List.hd type_decls).ptype_loc in
  if StringSet.mem "bool" typename_set then
    raise_errorf ~loc:here_loc
      "%s can't derivate types when shadowing bool (even with option)" deriver;
  let bindings =
    type_decls
    |> List.map (str_of_type ~options ~path typename_set)
    |> List.concat
  in
  [Str.value Recursive bindings]
(* Register the [eq] deriver with ppx_deriving. It handles:
   - core types: [expr_of_typ] with an empty group, so no name is treated
     as group-local;
   - structure type declarations: [type_decl_str];
   - signature type declarations: one [sig_of_type] item per declaration. *)
let () =
  let sig_of_decls ~options ~path decls =
    List.concat (List.map (sig_of_type ~options ~path) decls) in
  let eq_deriver =
    Ppx_deriving.create deriver
      ~core_type:(expr_of_typ StringSet.empty)
      ~type_decl_str
      ~type_decl_sig:sig_of_decls
      ()
  in
  Ppx_deriving.register eq_deriver