From 8a4130dea3087ca8e528322bf7898f91c1043fce Mon Sep 17 00:00:00 2001 From: Marceline Cramer Date: Fri, 24 Apr 2026 22:27:28 -0600 Subject: [PATCH] Add Salsa support to IR --- Cargo.lock | 1 + crates/ir/Cargo.toml | 2 ++ crates/ir/src/lib.rs | 47 ++++++++++++++++++++++++++++++++++++-------- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e6769f4..173b9e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -587,6 +587,7 @@ dependencies = [ "arbitrary", "derive-where", "ordered-float", + "salsa", "serde", ] diff --git a/crates/ir/Cargo.toml b/crates/ir/Cargo.toml index e163821..1acc33e 100644 --- a/crates/ir/Cargo.toml +++ b/crates/ir/Cargo.toml @@ -11,9 +11,11 @@ rust-version.workspace = true arbitrary = { workspace = true, optional = true } derive-where = { workspace = true, features = ["serde"] } ordered-float.workspace = true +salsa = { workspace = true, optional = true } serde = { workspace = true, optional = true } # TODO: test with feature power set [features] fuzz = ["dep:arbitrary", "ordered-float/arbitrary"] +salsa = ["dep:salsa"] serde = ["dep:serde"] diff --git a/crates/ir/src/lib.rs b/crates/ir/src/lib.rs index af0bc72..423ebb9 100644 --- a/crates/ir/src/lib.rs +++ b/crates/ir/src/lib.rs @@ -31,13 +31,16 @@ use ordered_float::OrderedFloat; #[cfg(feature = "fuzz")] use arbitrary::Arbitrary; +#[cfg(feature = "salsa")] +use salsa::Update; + #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; // TODO: write more top-level docs // TODO: reintroduce validation logic -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -52,7 +55,7 @@ pub struct Program { pub symbols: Vec, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -79,7 +82,7 @@ pub struct Relation { pub rules: Vec>, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -95,7 +98,7 @@ pub struct Rule { pub body: RuleBody, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -111,7 +114,7 @@ pub struct Assumption { pub body: RuleBody, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -132,7 +135,7 @@ pub struct RuleBody { pub body: Expr, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -142,7 +145,7 @@ pub struct Expr { pub kind: ExprKind, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive_where(Clone, Debug, PartialEq, Eq)] #[derive_where(PartialOrd, Ord; T: OrdProgramInfo)] #[derive_where(Hash; T: HashProgramInfo)] #[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))] @@ -314,7 +317,7 @@ impl Value { /// Associated types specializing a [Program] with stage-specific or format-specific information. // TODO: document how each type's usage changes based on stage and format -pub trait ProgramInfo: Info { +pub trait ProgramInfo { /// The type of global symbol labels. type SymbolLabel: Info; @@ -339,6 +342,8 @@ pub trait ProgramInfo: Info { /// A blanket trait bounding all associated types in [ProgramInfo]. pub trait Info: Clone + Debug + Eq {} +impl Info for T {} + /// Creates a new trait to further bound the members of [ProgramInfo]. macro_rules! def_bound_info { ($name:ident, $bounds:tt) => { @@ -375,6 +380,32 @@ def_bound_info!(HashProgramInfo, Hash); #[cfg(feature = "fuzz")] def_bound_info!(ArbitraryProgramInfo, (for<'a> Arbitrary<'a>)); +#[cfg(feature = "salsa")] +macro_rules! impl_update { + () => {}; + ($head:ident, $($tail:ident),*) => { + impl_update!($head); + impl_update!($($tail),*); + }; + ($item:ident) => { + unsafe impl Update for $item { + unsafe fn maybe_update(old: *mut Self, new: Self) -> bool { + let old: &mut Self = unsafe { &mut *old }; + + if *old == new { + false + } else { + let _ = std::mem::replace(old, new); + true + } + } + } + }; +} + +#[cfg(feature = "salsa")] +impl_update!(Program, Relation, Rule, RuleBody, Assumption, Expr, ExprKind); + #[cfg(feature = "serde")] def_bound_info!(DeserializeProgramInfo, (for<'a> Deserialize<'a>));