Add Salsa support to IR

This commit is contained in:
2026-04-24 22:27:28 -06:00
parent 81d43d2e13
commit 8a4130dea3
3 changed files with 42 additions and 8 deletions

View File

@@ -11,9 +11,11 @@ rust-version.workspace = true
arbitrary = { workspace = true, optional = true }
derive-where = { workspace = true, features = ["serde"] }
ordered-float.workspace = true
salsa = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
# TODO: test with feature power set
[features]
fuzz = ["dep:arbitrary", "ordered-float/arbitrary"]
salsa = ["dep:salsa"]
serde = ["dep:serde"]

View File

@@ -31,13 +31,16 @@ use ordered_float::OrderedFloat;
#[cfg(feature = "fuzz")]
use arbitrary::Arbitrary;
#[cfg(feature = "salsa")]
use salsa::Update;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
// TODO: write more top-level docs
// TODO: reintroduce validation logic
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -52,7 +55,7 @@ pub struct Program<T: ProgramInfo> {
pub symbols: Vec<T::SymbolLabel>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -79,7 +82,7 @@ pub struct Relation<T: ProgramInfo> {
pub rules: Vec<Rule<T>>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -95,7 +98,7 @@ pub struct Rule<T: ProgramInfo> {
pub body: RuleBody<T>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -111,7 +114,7 @@ pub struct Assumption<T: ProgramInfo> {
pub body: RuleBody<T>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -132,7 +135,7 @@ pub struct RuleBody<T: ProgramInfo> {
pub body: Expr<T>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -142,7 +145,7 @@ pub struct Expr<T: ProgramInfo> {
pub kind: ExprKind<T>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive_where(Clone, Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord; T: OrdProgramInfo)]
#[derive_where(Hash; T: HashProgramInfo)]
#[cfg_attr(feature = "fuzz", derive_where(Arbitrary; T: ArbitraryProgramInfo))]
@@ -314,7 +317,7 @@ impl Value {
/// Associated types specializing a [Program] with stage-specific or format-specific information.
// TODO: document how each type's usage changes based on stage and format
pub trait ProgramInfo: Info {
pub trait ProgramInfo {
/// The type of global symbol labels.
type SymbolLabel: Info;
@@ -339,6 +342,8 @@ pub trait ProgramInfo: Info {
/// A blanket trait bounding all associated types in [ProgramInfo].
pub trait Info: Clone + Debug + Eq {}
impl<T: Clone + Debug + Eq> Info for T {}
/// Creates a new trait to further bound the members of [ProgramInfo].
macro_rules! def_bound_info {
($name:ident, $bounds:tt) => {
@@ -375,6 +380,32 @@ def_bound_info!(HashProgramInfo, Hash);
#[cfg(feature = "fuzz")]
def_bound_info!(ArbitraryProgramInfo, (for<'a> Arbitrary<'a>));
#[cfg(feature = "salsa")]
macro_rules! impl_update {
() => {};
($head:ident, $($tail:ident),*) => {
impl_update!($head);
impl_update!($($tail),*);
};
($item:ident) => {
unsafe impl<T: ProgramInfo> Update for $item<T> {
unsafe fn maybe_update(old: *mut Self, new: Self) -> bool {
let old: &mut Self = unsafe { &mut *old };
if *old == new {
false
} else {
let _ = std::mem::replace(old, new);
true
}
}
}
};
}
#[cfg(feature = "salsa")]
impl_update!(Program, Relation, Rule, RuleBody, Assumption, Expr, ExprKind);
#[cfg(feature = "serde")]
def_bound_info!(DeserializeProgramInfo, (for<'a> Deserialize<'a>));