Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ enum Target {
use enum_convert::EnumFrom;

enum Source {
Tuple(String, u8),
Record {
name: String,
value: i32,
Expand All @@ -109,6 +110,12 @@ enum Source {
#[derive(EnumFrom)]
#[enum_from(Source)]
enum Target {
#[enum_from]
Tuple(
// We effectively re-order fields
#[enum_from(Source::Tuple.1)] u8,
#[enum_from(Source::Tuple.0)] String,
),
#[enum_from]
Record {
#[enum_from(Source::Record.name)] // Maps Source::Record.name to Target::Record.title
Expand Down
194 changes: 132 additions & 62 deletions src/enum_from/generator.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};

use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Fields, Variant};

use crate::{
enum_from::parser::{ContainerAnnotation, FieldAnnotations, ParsedEnumFrom, VariantAnnotation},
idents::{ContainerIdent, FieldIdent, VariantIdent},
enum_from::parser::{
ContainerAnnotation, FieldAnnotation, FieldAnnotations, ParsedEnumFrom, VariantAnnotation,
},
idents::{ContainerIdent, FieldIdent, FieldRef, VariantIdent},
};

/// A struct holding all the data necessary to generate a TokenStream.
Expand All @@ -19,9 +21,28 @@ pub struct EnumFromGenerator {

struct VariantsMapping(HashMap<VariantIdent, VariantMapping>);

struct VariantMapping {
target_variant: VariantIdent,
fields_mapping: HashMap<FieldIdent, FieldIdent>,
enum VariantMapping {
Unit {
target_variant: VariantIdent,
},
Tuple {
target_variant: VariantIdent,
fields_mapping: HashMap<usize, usize>,
},
Struct {
target_variant: VariantIdent,
fields_mapping: HashMap<FieldIdent, FieldIdent>,
},
}

impl VariantMapping {
fn target_variant(&self) -> &VariantIdent {
match self {
VariantMapping::Unit { target_variant } => target_variant,
VariantMapping::Tuple { target_variant, .. } => target_variant,
VariantMapping::Struct { target_variant, .. } => target_variant,
}
}
}

impl EnumFromGenerator {
Expand Down Expand Up @@ -49,22 +70,23 @@ fn generate_from_impl(
target_enum: &ContainerIdent,
target_variants: &HashMap<VariantIdent, Variant>,
) -> TokenStream {
let match_arms = variants_mapping
.0
.into_iter()
.map(|(source_variant, variant_mapping)| {
let target_variant = target_variants.get(&variant_mapping.target_variant).expect(
"All target variants in variant_mapping should be present in target_variants",
);
generate_match_arm(
source_variant,
variant_mapping,
&source_enum,
target_enum,
target_variant,
)
})
.collect::<Vec<_>>();
let match_arms =
variants_mapping
.0
.into_iter()
.map(|(source_variant, variant_mapping)| {
let target_variant = target_variants.get(variant_mapping.target_variant()).expect(
"All target variants in variant_mapping should be present in target_variants",
);
generate_match_arm(
source_variant,
variant_mapping,
&source_enum,
target_enum,
target_variant,
)
})
.collect::<Vec<_>>();

quote! {
impl From<#source_enum> for #target_enum {
Expand All @@ -84,26 +106,41 @@ fn generate_match_arm(
target_enum: &ContainerIdent,
variant: &Variant,
) -> TokenStream {
let target_variant = &variant.ident;

match &variant.fields {
Fields::Unit => quote! {
#source_enum::#source_variant => #target_enum::#target_variant,
},
Fields::Unnamed(fields) => {
let field_names: Vec<_> = (0..fields.unnamed.len())
.map(|i| quote::format_ident!("field_{}", i))
.collect();
let field_conversions: Vec<_> = field_names
.iter()
.map(|name| quote! { #name.into() })
.collect();
match (&variant.fields, variant_mapping) {
(Fields::Unit, VariantMapping::Unit { target_variant }) => {
quote! { #source_enum::#source_variant => #target_enum::#target_variant, }
}
(
Fields::Unnamed(fields),
VariantMapping::Tuple {
target_variant,
fields_mapping,
},
) => {
let (source_fields, target_fields): (Vec<_>, Vec<_>) = (0..fields.unnamed.len())
.map(|field_target_pos| {
let field_source_pos = fields_mapping
.get(&field_target_pos)
.unwrap_or(&field_target_pos);
let target_field_name = quote::format_ident!("field_{field_target_pos}");
(
quote::format_ident!("field_{field_source_pos}"),
quote! { #target_field_name.into() },
)
})
.unzip();
quote! {
#source_enum::#source_variant(#(#field_names),*) =>
#target_enum::#target_variant(#(#field_conversions),*),
#source_enum::#source_variant(#(#source_fields),*) =>
#target_enum::#target_variant(#(#target_fields),*),
}
}
Fields::Named(fields) => {
(
Fields::Named(fields),
VariantMapping::Struct {
target_variant,
fields_mapping,
},
) => {
let (source_fields, target_fields): (Vec<_>, Vec<_>) = fields
.named
.iter()
Expand All @@ -115,10 +152,7 @@ fn generate_match_arm(
.expect("A named field should always have an ident")
.clone(),
);
let source_field = &variant_mapping
.fields_mapping
.get(&target_field)
.unwrap_or(&target_field);
let source_field = &fields_mapping.get(&target_field).unwrap_or(&target_field);
(
quote! { #source_field },
quote! { #target_field: #source_field.into() },
Expand All @@ -131,6 +165,7 @@ fn generate_match_arm(
#target_enum::#target_variant { #(#target_fields),* },
}
}
(_, _) => panic!("Unexpected mixing of variant types"),
}
}

Expand Down Expand Up @@ -173,7 +208,7 @@ impl TryFrom<ParsedEnumFrom> for EnumFromGenerator {
variant_annotation,
)?;

let VariantsMapping(variant_mapping) = source_enums.get_mut(&source_enum).ok_or_else(|| {
let VariantsMapping(variants_mapping) = source_enums.get_mut(&source_enum).ok_or_else(|| {
syn::Error::new(
span,
format!(
Expand All @@ -187,14 +222,51 @@ impl TryFrom<ParsedEnumFrom> for EnumFromGenerator {
&source_enum,
&source_variant,
)?;

variant_mapping.insert(
source_variant,
VariantMapping {
target_variant: VariantIdent(target_variant.ident.clone()),
fields_mapping,
let fields = &target_variant.fields;
let target_variant = VariantIdent(target_variant.ident.clone());
let variant_mapping = match fields {
Fields::Unit => VariantMapping::Unit { target_variant },
Fields::Unnamed(_) => VariantMapping::Tuple {
target_variant,
fields_mapping: fields_mapping
.into_iter()
.map(|target_to_source| match target_to_source {
(
FieldRef::FieldPos(target_pos),
FieldAnnotation {
source_field: FieldRef::FieldPos(source_pos),
..
},
) => Ok((target_pos, source_pos)),
(_, FieldAnnotation { field_span, .. }) => Err(syn::Error::new(
field_span,
"Unexpected mapping to named field for tuple variant",
)),
})
.collect::<syn::Result<_>>()?,
},
);
Fields::Named(_) => VariantMapping::Struct {
target_variant,
fields_mapping: fields_mapping
.into_iter()
.map(|target_to_source| match target_to_source {
(
FieldRef::FieldIdent(target_ident),
FieldAnnotation {
source_field: FieldRef::FieldIdent(source_ident),
..
},
) => Ok((target_ident, source_ident)),
(_, FieldAnnotation { field_span, .. }) => Err(syn::Error::new(
field_span,
"Unexpected mapping to positional field for struct variant",
)),
})
.collect::<syn::Result<_>>()?,
},
};

variants_mapping.insert(source_variant, variant_mapping);
}
target_variants.insert(VariantIdent(target_variant.ident.clone()), target_variant);
}
Expand All @@ -208,13 +280,13 @@ impl TryFrom<ParsedEnumFrom> for EnumFromGenerator {
}

fn get_fields_mapping(
fields_annotations: &HashMap<FieldIdent, FieldAnnotations>,
fields_annotations: &HashMap<FieldRef, FieldAnnotations>,
source_enum: &ContainerIdent,
source_variant: &VariantIdent,
) -> syn::Result<HashMap<FieldIdent, FieldIdent>> {
) -> syn::Result<BTreeMap<FieldRef, FieldAnnotation>> {
Ok(fields_annotations
.iter()
.map(|(target_field, field_annotations)| {
.filter_map(|(target_field, field_annotations)| {
let annotations = field_annotations
.fields_annotations
.iter()
Expand All @@ -223,16 +295,14 @@ fn get_fields_mapping(
&& field_annotation.source_variant == *source_variant
})
.collect::<Vec<_>>();
let source_field = match annotations.len() {
0 => target_field.clone(),
1 => annotations[0].source_field.clone(),
_ => Err(syn::Error::new(
match annotations.len() {
0 => None,
1 => Some(Ok((target_field.clone(), annotations[0].clone()))),
_ => Some(Err(syn::Error::new(
field_annotations.field_span,
format!("Multiple mapping found for source enum `{source_enum}`"),
))?,
};

Ok((target_field.clone(), source_field))
))),
}
})
.collect::<syn::Result<Vec<_>>>()?
.into_iter()
Expand Down
41 changes: 31 additions & 10 deletions src/enum_from/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use std::collections::HashMap;
use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::{
Attribute, Data, DataEnum, DeriveInput, Field, Ident, Meta, Path, Token, Variant,
Attribute, Data, DataEnum, DeriveInput, Field, Ident, LitInt, Meta, Path, Token, Variant,
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
};

use crate::idents::{ContainerIdent, FieldIdent, VariantIdent};
use crate::idents::{ContainerIdent, FieldIdent, FieldRef, VariantIdent};

/// A "dumb" parser of the EnumFrom annotations
/// There is no check of consistency between annotations here.
Expand Down Expand Up @@ -47,7 +47,7 @@ pub struct ContainerAnnotation(pub ContainerIdent);

pub struct VariantAnnotations {
pub variant_annotations: Vec<VariantAnnotation>,
pub fields_annotations: HashMap<FieldIdent, FieldAnnotations>,
pub fields_annotations: HashMap<FieldRef, FieldAnnotations>,
}

pub enum VariantAnnotation {
Expand Down Expand Up @@ -94,10 +94,12 @@ pub struct FieldAnnotations {
pub field_span: Span,
}

#[derive(Clone)]
pub struct FieldAnnotation {
pub source_enum: ContainerIdent,
pub source_variant: VariantIdent,
pub source_field: FieldIdent,
pub source_field: FieldRef,
pub field_span: Span,
}

impl Parse for FieldAnnotation {
Expand All @@ -107,11 +109,22 @@ impl Parse for FieldAnnotation {
let source_enum = ContainerIdent(path.segments[0].ident.clone());
let source_variant = VariantIdent(path.segments[1].ident.clone());
input.parse::<Token![.]>()?;
let source_field = FieldIdent(input.parse()?);
let field_span = input.span();
let source_field = if let Ok(ident) = input.parse::<Ident>() {
FieldRef::FieldIdent(FieldIdent(ident))
} else if let Ok(lit) = input.parse::<LitInt>() {
FieldRef::FieldPos(lit.base10_parse()?)
} else {
Err(syn::Error::new(
field_span,
"Expected either a field identifier or a field position",
))?
};
Ok(FieldAnnotation {
source_enum,
source_variant,
source_field,
field_span,
})
} else {
Err(syn::Error::new_spanned(
Expand Down Expand Up @@ -212,11 +225,19 @@ fn extract_variant_annotations(variant: &Variant) -> syn::Result<VariantAnnotati
let fields_annotations = variant
.fields
.iter()
.filter_map(|field| {
field.ident.as_ref().map(|field_ident| {
extract_field_annotations(field)
.map(|field_annotations| (FieldIdent(field_ident.clone()), field_annotations))
})
.enumerate()
.map(|(pos, field)| {
let annotations = extract_field_annotations(field);
match &field.ident {
Some(field_ident) => annotations.map(|field_annotations| {
(
FieldRef::FieldIdent(FieldIdent(field_ident.clone())),
field_annotations,
)
}),
None => annotations
.map(|field_annotations| (FieldRef::FieldPos(pos), field_annotations)),
}
})
.collect::<syn::Result<Vec<_>>>()?
.into_iter()
Expand Down
Loading