summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJohn Turner <jturner.usa@gmail.com>2025-10-26 20:20:06 -0400
committerJohn Turner <jturner.usa@gmail.com>2025-10-26 20:20:06 -0400
commit287090c95ae6200bd60bd50021f6ca05130a5cc6 (patch)
tree22fbbc1f4f9fdfca97e02c48154127c86951f14c /src
parentc5c6b646a5605221e476b4ea6cce6eccb1779552 (diff)
downloadget-287090c95ae6200bd60bd50021f6ca05130a5cc6.tar.gz
impl deref getters
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs169
1 files changed, 102 insertions, 67 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 7154b6c..83eb741 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -60,9 +60,9 @@
//!```
//! use get::GetCopy;
//!
-//! #[derive(Clone, Copy, GetCopy)]
+//! #[derive(Clone, Copy, Get)]
//! struct NonZeroUInt<T>(
-//! #[get(method = "inner")] T
+//! #[get(method = "inner", kind = "copy")] T
//! );
//!
//! fn non_zero_uint() {
@@ -92,24 +92,25 @@
//! All supported idents are:
//! * `hide` (this will disable getters for a specific field)
//!
+//! All supported getter kinds are:
+//! * `ref`
+//! * `copy`
+//! * `as_ref`
+//! * `deref`
+//!
//! # Todos
//! * detect and return error for fields with conflicting attributes
//! * improve error output, include span information if possible
-//! * AsRef or Deref getters
#[proc_macro_derive(Get, attributes(get))]
pub fn get(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let parsed_input = syn::parse_macro_input!(input as syn::DeriveInput);
- get::expand(&parsed_input, false).unwrap().into()
-}
-
-#[proc_macro_derive(GetCopy, attributes(get))]
-pub fn get_copy(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
- let parsed_input = syn::parse_macro_input!(input as syn::DeriveInput);
- get::expand(&parsed_input, true).unwrap().into()
+ get::expand(&parsed_input).unwrap().into()
}
mod get {
+ use core::option::Option::{self, None};
+
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, format_ident, quote};
use syn::{
@@ -123,13 +124,12 @@ mod get {
}
#[derive(Default)]
- struct GetNameValueList {
- method: Option<String>,
- }
+ struct GetNameValueList(Vec<GetNameValue>);
#[derive(Debug, Clone)]
enum GetNameValue {
Method(String),
+ Kind(String),
}
#[derive(Debug, Clone)]
@@ -137,16 +137,41 @@ mod get {
Hide,
}
- pub fn expand(
- input: &DeriveInput,
- is_copy: bool,
- ) -> Result<TokenStream, Box<dyn std::error::Error>> {
+ #[derive(Debug, Clone, Copy)]
+ pub enum GetterKind {
+ Ref,
+ Move,
+ Deref,
+ }
+
+ impl GetNameValueList {
+ fn method(&self) -> Option<String> {
+ self.0.iter().find_map(|x| match x {
+ GetNameValue::Method(name) => Some(name.clone()),
+ _ => None,
+ })
+ }
+
+ fn kind(&self) -> Option<GetterKind> {
+ self.0.iter().find_map(|x| match x {
+ GetNameValue::Kind(kind) => Some(match kind.as_str() {
+ "ref" => GetterKind::Ref,
+ "move" => GetterKind::Move,
+ "deref" => GetterKind::Deref,
+ _ => return None,
+ }),
+ _ => None,
+ })
+ }
+ }
+
+ pub fn expand(input: &DeriveInput) -> Result<TokenStream, Box<dyn std::error::Error>> {
let Data::Struct(target) = &input.data else {
return Err("expected struct as derive input".into());
};
let getters = match &target.fields {
- Fields::Unnamed(fields) => expand_for_tuple_struct(fields.unnamed.iter(), is_copy)?,
- Fields::Named(fields) => expand_for_struct(fields.named.iter(), is_copy)?,
+ Fields::Unnamed(fields) => expand_for_tuple_struct(fields.unnamed.iter())?,
+ Fields::Named(fields) => expand_for_struct(fields.named.iter())?,
_ => return Err("can not generate getters on a unit struct".into()),
};
let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
@@ -161,7 +186,6 @@ mod get {
fn expand_for_struct<'a>(
fields: impl Iterator<Item = &'a Field>,
- is_copy: bool,
) -> Result<TokenStream, Box<dyn std::error::Error>> {
let mut tokens = TokenStream::new();
for field in fields {
@@ -175,24 +199,21 @@ mod get {
{
continue;
}
- Some(Ok(GetAttribute::NameValueList(list))) => {
- let method_name = list
- .method
- .map(|s| format_ident!("{s}"))
- .unwrap_or(field.ident.as_ref().cloned().unwrap());
- expand_getter(
- field,
- &Member::Named(field.ident.as_ref().cloned().unwrap()),
- &method_name,
- is_copy,
- )
- }
+ Some(Ok(GetAttribute::NameValueList(list))) => expand_getter(
+ field,
+ &Member::Named(field.ident.as_ref().cloned().unwrap()),
+ &list
+ .method()
+ .map(|name| format_ident! {"{name}"})
+ .unwrap_or(field.ident.as_ref().cloned().unwrap()),
+ list.kind().unwrap_or(GetterKind::Ref),
+ ),
Some(Err(e)) => return Err(e),
_ => expand_getter(
field,
&Member::Named(field.ident.as_ref().cloned().unwrap()),
field.ident.as_ref().unwrap(),
- is_copy,
+ GetterKind::Ref,
),
}
.to_tokens(&mut tokens);
@@ -202,7 +223,6 @@ mod get {
fn expand_for_tuple_struct<'a>(
fields: impl Iterator<Item = &'a Field>,
- is_copy: bool,
) -> Result<TokenStream, Box<dyn std::error::Error>> {
let mut tokens = TokenStream::new();
for (i, field) in fields.enumerate() {
@@ -216,15 +236,15 @@ mod get {
{
continue;
}
- Some(Ok(GetAttribute::NameValueList(list))) if list.method.is_some() => {
+ Some(Ok(GetAttribute::NameValueList(list))) if list.method().is_some() => {
expand_getter(
field,
&Member::Unnamed(Index {
index: i.try_into().unwrap(),
span: Span::call_site(),
}),
- &list.method.map(|s| format_ident!("{s}")).unwrap(),
- is_copy,
+ &list.method().map(|s| format_ident!("{s}")).unwrap(),
+ list.kind().unwrap_or(GetterKind::Ref),
)
}
Some(Err(e)) => return Err(e),
@@ -239,28 +259,45 @@ mod get {
field: &Field,
field_name: &Member,
method_name: &Ident,
- is_copy: bool,
+ kind: GetterKind,
) -> TokenStream {
let field_type = &field.ty;
+
let field_lifetime = match &field.ty {
Type::Reference(type_ref) => Some(&type_ref.lifetime),
_ => None,
};
- let method_args = if is_copy {
+
+ let method_args = if let GetterKind::Move = kind {
quote! { ( self ) }
} else {
quote! { ( & #field_lifetime self ) }
};
- let method_type = if is_copy {
- quote! { #field_type }
- } else {
- quote! { & #field_type }
+
+ let method_type = match kind {
+ GetterKind::Ref => quote! {
+ & #field_type
+ },
+ GetterKind::Move => quote! {
+ #field_type
+ },
+ GetterKind::Deref => quote! {
+ & < #field_type as :: std :: ops :: Deref > :: Target
+ },
};
- let method_body = if is_copy {
- quote! { { self . #field_name } }
- } else {
- quote! { { & self . #field_name } }
+
+ let method_body = match kind {
+ GetterKind::Ref => quote! {
+ { & self . #field_name }
+ },
+ GetterKind::Move => quote! {
+ { self . #field_name }
+ },
+ GetterKind::Deref => quote! {
+ { < #field_type as std :: ops :: Deref > :: deref ( &self . #field_name ) }
+ },
};
+
quote! {
pub fn
#method_name
@@ -299,24 +336,19 @@ mod get {
impl TryFrom<Punctuated<MetaNameValue, Token![,]>> for GetNameValueList {
type Error = Box<dyn std::error::Error>;
fn try_from(punct: Punctuated<MetaNameValue, Token![,]>) -> Result<Self, Self::Error> {
- Ok(punct
- .into_iter()
- .map(GetNameValue::try_from)
- .collect::<Result<Vec<GetNameValue>, _>>()
- .map(|v| {
- if !v.is_empty() {
- Ok::<Vec<GetNameValue>, Box<dyn std::error::Error>>(v)
- } else {
- Err("expected at least 1 name value pair in attribute".into())
- }
- })??
- .into_iter()
- .fold(Self::default(), |acc, n| match n {
- GetNameValue::Method(m) => Self {
- method: Some(m),
- ..acc
- },
- }))
+ Ok(Self(
+ punct
+ .into_iter()
+ .map(GetNameValue::try_from)
+ .collect::<Result<Vec<GetNameValue>, _>>()
+ .map(|v| {
+ if !v.is_empty() {
+ Ok::<Vec<GetNameValue>, Box<dyn std::error::Error>>(v)
+ } else {
+ Err("expected at least 1 name value pair in attribute".into())
+ }
+ })??,
+ ))
}
}
@@ -336,9 +368,12 @@ mod get {
if let Some(name) = meta.path.get_ident().map(|ident| ident.to_string())
&& let Expr::Lit(expr) = &meta.value
&& let Lit::Str(s) = &expr.lit
- && let "method" = name.as_str()
{
- Ok(Self::Method(s.value()))
+ match name.as_str() {
+ "method" => Ok(Self::Method(s.value())),
+ "kind" => Ok(Self::Kind(s.value())),
+ _ => Err(format!("invalid value in attribute: {}", name.as_str()).into()),
+ }
} else {
Err("invalid name value pair in attribute".into())
}