diff options
| -rw-r--r-- | src/lib.rs | 169 | ||||
| -rw-r--r-- | tests/get.rs | 60 |
2 files changed, 114 insertions, 115 deletions
@@ -60,9 +60,9 @@ //!``` //! use get::GetCopy; //! -//! #[derive(Clone, Copy, GetCopy)] +//! #[derive(Clone, Copy, Get)] //! struct NonZeroUInt<T>( -//! #[get(method = "inner")] T +//! #[get(method = "inner", kind = "copy")] T //! ); //! //! fn non_zero_uint() { @@ -92,24 +92,25 @@ //! All supported idents are: //! * `hide` (this will disable getters for a specific field) //! +//! All supported getter kinds are: +//! * `ref` +//! * `copy` +//! * `as_ref` +//! * `deref` +//! //! # Todos //! * detect and return error for fields with conflicting attributes //! * improve error output, include span information if possible -//! * AsRef or Deref getters #[proc_macro_derive(Get, attributes(get))] pub fn get(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let parsed_input = syn::parse_macro_input!(input as syn::DeriveInput); - get::expand(&parsed_input, false).unwrap().into() -} - -#[proc_macro_derive(GetCopy, attributes(get))] -pub fn get_copy(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let parsed_input = syn::parse_macro_input!(input as syn::DeriveInput); - get::expand(&parsed_input, true).unwrap().into() + get::expand(&parsed_input).unwrap().into() } mod get { + use core::option::Option::{self, None}; + use proc_macro2::{Span, TokenStream}; use quote::{ToTokens, format_ident, quote}; use syn::{ @@ -123,13 +124,12 @@ mod get { } #[derive(Default)] - struct GetNameValueList { - method: Option<String>, - } + struct GetNameValueList(Vec<GetNameValue>); #[derive(Debug, Clone)] enum GetNameValue { Method(String), + Kind(String), } #[derive(Debug, Clone)] @@ -137,16 +137,41 @@ mod get { Hide, } - pub fn expand( - input: &DeriveInput, - is_copy: bool, - ) -> Result<TokenStream, Box<dyn std::error::Error>> { + #[derive(Debug, Clone, Copy)] + pub enum GetterKind { + Ref, + Move, + Deref, + } + + impl GetNameValueList { + fn method(&self) -> Option<String> { + self.0.iter().find_map(|x| match x { + GetNameValue::Method(name) => Some(name.clone()), + _ => None, + }) + } + + fn kind(&self) -> Option<GetterKind> { + self.0.iter().find_map(|x| match x { + GetNameValue::Kind(kind) => Some(match kind.as_str() { + "ref" => GetterKind::Ref, + "move" => GetterKind::Move, + "deref" => GetterKind::Deref, + _ => return None, + }), + _ => None, + }) + } + } + + pub fn expand(input: &DeriveInput) -> Result<TokenStream, Box<dyn std::error::Error>> { let Data::Struct(target) = &input.data else { return Err("expected struct as derive input".into()); }; let getters = match &target.fields { - Fields::Unnamed(fields) => expand_for_tuple_struct(fields.unnamed.iter(), is_copy)?, - Fields::Named(fields) => expand_for_struct(fields.named.iter(), is_copy)?, + Fields::Unnamed(fields) => expand_for_tuple_struct(fields.unnamed.iter())?, + Fields::Named(fields) => expand_for_struct(fields.named.iter())?, _ => return Err("can not generate getters on a unit struct".into()), }; let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); @@ -161,7 +186,6 @@ mod get { fn expand_for_struct<'a>( fields: impl Iterator<Item = &'a Field>, - is_copy: bool, ) -> Result<TokenStream, Box<dyn std::error::Error>> { let mut tokens = TokenStream::new(); for field in fields { @@ -175,24 +199,21 @@ mod get { { continue; } - Some(Ok(GetAttribute::NameValueList(list))) => { - let method_name = list - .method - .map(|s| format_ident!("{s}")) - .unwrap_or(field.ident.as_ref().cloned().unwrap()); - expand_getter( - field, - &Member::Named(field.ident.as_ref().cloned().unwrap()), - &method_name, - is_copy, - ) - } + Some(Ok(GetAttribute::NameValueList(list))) => expand_getter( + field, + &Member::Named(field.ident.as_ref().cloned().unwrap()), + &list + .method() + .map(|name| format_ident! {"{name}"}) + .unwrap_or(field.ident.as_ref().cloned().unwrap()), + list.kind().unwrap_or(GetterKind::Ref), + ), Some(Err(e)) => return Err(e), _ => expand_getter( field, &Member::Named(field.ident.as_ref().cloned().unwrap()), field.ident.as_ref().unwrap(), - is_copy, + GetterKind::Ref, ), } .to_tokens(&mut tokens); @@ -202,7 +223,6 @@ mod get { fn expand_for_tuple_struct<'a>( fields: impl Iterator<Item = &'a Field>, - is_copy: bool, ) -> Result<TokenStream, Box<dyn std::error::Error>> { let mut tokens = TokenStream::new(); for (i, field) in fields.enumerate() { @@ -216,15 +236,15 @@ mod get { { continue; } - Some(Ok(GetAttribute::NameValueList(list))) if list.method.is_some() => { + Some(Ok(GetAttribute::NameValueList(list))) if list.method().is_some() => { expand_getter( field, &Member::Unnamed(Index { index: i.try_into().unwrap(), span: Span::call_site(), }), - &list.method.map(|s| format_ident!("{s}")).unwrap(), - is_copy, + &list.method().map(|s| format_ident!("{s}")).unwrap(), + list.kind().unwrap_or(GetterKind::Ref), ) } Some(Err(e)) => return Err(e), @@ -239,28 +259,45 @@ mod get { field: &Field, field_name: &Member, method_name: &Ident, - is_copy: bool, + kind: GetterKind, ) -> TokenStream { let field_type = &field.ty; + let field_lifetime = match &field.ty { Type::Reference(type_ref) => Some(&type_ref.lifetime), _ => None, }; - let method_args = if is_copy { + + let method_args = if let GetterKind::Move = kind { quote! { ( self ) } } else { quote! { ( & #field_lifetime self ) } }; - let method_type = if is_copy { - quote! { #field_type } - } else { - quote! { & #field_type } + + let method_type = match kind { + GetterKind::Ref => quote! { + & #field_type + }, + GetterKind::Move => quote! { + #field_type + }, + GetterKind::Deref => quote! { + & < #field_type as :: std :: ops :: Deref > :: Target + }, }; - let method_body = if is_copy { - quote! { { self . #field_name } } - } else { - quote! { { & self . #field_name } } + + let method_body = match kind { + GetterKind::Ref => quote! { + { & self . #field_name } + }, + GetterKind::Move => quote! { + { self . #field_name } + }, + GetterKind::Deref => quote! { + { < #field_type as std :: ops :: Deref > :: deref ( &self . #field_name ) } + }, }; + quote! { pub fn #method_name @@ -299,24 +336,19 @@ mod get { impl TryFrom<Punctuated<MetaNameValue, Token![,]>> for GetNameValueList { type Error = Box<dyn std::error::Error>; fn try_from(punct: Punctuated<MetaNameValue, Token![,]>) -> Result<Self, Self::Error> { - Ok(punct - .into_iter() - .map(GetNameValue::try_from) - .collect::<Result<Vec<GetNameValue>, _>>() - .map(|v| { - if !v.is_empty() { - Ok::<Vec<GetNameValue>, Box<dyn std::error::Error>>(v) - } else { - Err("expected at least 1 name value pair in attribute".into()) - } - })?? - .into_iter() - .fold(Self::default(), |acc, n| match n { - GetNameValue::Method(m) => Self { - method: Some(m), - ..acc - }, - })) + Ok(Self( + punct + .into_iter() + .map(GetNameValue::try_from) + .collect::<Result<Vec<GetNameValue>, _>>() + .map(|v| { + if !v.is_empty() { + Ok::<Vec<GetNameValue>, Box<dyn std::error::Error>>(v) + } else { + Err("expected at least 1 name value pair in attribute".into()) + } + })??, + )) } } @@ -336,9 +368,12 @@ mod get { if let Some(name) = meta.path.get_ident().map(|ident| ident.to_string()) && let Expr::Lit(expr) = &meta.value && let Lit::Str(s) = &expr.lit - && let "method" = name.as_str() { - Ok(Self::Method(s.value())) + match name.as_str() { + "method" => Ok(Self::Method(s.value())), + "kind" => Ok(Self::Kind(s.value())), + _ => Err(format!("invalid value in attribute: {}", name.as_str()).into()), + } } else { Err("invalid name value pair in attribute".into()) } diff --git a/tests/get.rs b/tests/get.rs index 8ff2c24..ccb6d9d 100644 --- a/tests/get.rs +++ b/tests/get.rs @@ -16,74 +16,38 @@ fn trybuild() { mod get { use get::Get; #[derive(Get)] - pub struct Cat<'a, T> { - name: &'a str, + pub struct Cat<T> { + #[get(kind = "deref")] + name: String, + #[get(kind = "move")] age: u64, owner: T, } - #[derive(Get)] - pub struct CatTuple<'a, T>( - #[get(method = "name")] &'a str, - #[get(method = "age")] u64, - #[get(method = "owner")] T, - ); - - #[test] - fn cat_struct() { - let cat = Cat { - name: "cat", - age: 1, - owner: (), - }; - assert_eq!(*cat.name(), "cat"); - assert_eq!(*cat.age(), 1); - assert!(matches!(cat.owner(), ())); - } - - #[test] - fn cat_tuple_struct() { - let cat = CatTuple("cat", 1, ()); - assert_eq!(*cat.name(), "cat"); - assert_eq!(*cat.age(), 1); - assert!(matches!(cat.owner(), ())); - } -} - -mod get_copy { - use get::GetCopy; - - #[derive(Clone, Copy, GetCopy)] - pub struct Cat<'a, T> { - name: &'a str, - age: u64, - owner: T, - } - - #[derive(Clone, Copy, GetCopy)] - pub struct CatTuple<'a, T>( - #[get(method = "name")] &'a str, - #[get(method = "age")] u64, + #[derive(Clone, Get)] + pub struct CatTuple<T>( + #[get(method = "name", kind = "deref")] String, + #[get(method = "age", kind = "move")] u64, #[get(method = "owner")] T, ); #[test] fn cat_struct() { let cat = Cat { - name: "cat", + name: "cat".to_string(), age: 1, owner: (), }; assert_eq!(cat.name(), "cat"); - assert_eq!(cat.age(), 1); assert!(matches!(cat.owner(), ())); + assert_eq!(cat.age(), 1); } #[test] fn cat_tuple_struct() { - let cat = CatTuple("cat", 1, ()); + let cat = CatTuple("cat".to_string(), 1, ()); assert_eq!(cat.name(), "cat"); - assert_eq!(cat.age(), 1); assert!(matches!(cat.owner(), ())); + assert_eq!(cat.age(), 1); } } |
