diff --git a/src/project/manifest/environment.rs b/src/project/manifest/environment.rs index f92c60ee5..8582eadfa 100644 --- a/src/project/manifest/environment.rs +++ b/src/project/manifest/environment.rs @@ -1,4 +1,6 @@ use crate::consts; +use crate::utils::spanned::PixiSpanned; +use serde::{self, Deserialize, Deserializer}; /// The name of an environment. This is either a string or default for the default environment. #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] @@ -18,6 +20,18 @@ impl EnvironmentName { } } +impl<'de> Deserialize<'de> for EnvironmentName { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + match String::deserialize(deserializer)? { + name if name == consts::DEFAULT_ENVIRONMENT_NAME => Ok(EnvironmentName::Default), + name => Ok(EnvironmentName::Named(name)), + } + } +} + /// An environment describes a set of features that are available together. /// /// Individual features cannot be used directly, instead they are grouped together into @@ -33,10 +47,57 @@ pub struct Environment { /// environment. pub features: Vec, - /// The optional location of where the features are defined in the manifest toml. + /// The optional location of where the features of the environment are defined in the manifest toml. pub features_source_loc: Option>, /// An optional solver-group. Multiple environments can share the same solve-group. All the /// dependencies of the environment that share the same solve-group will be solved together. pub solve_group: Option, } + +/// Helper struct to deserialize the environment from TOML. +/// The environment description can only hold these values. +#[derive(Deserialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub(super) struct TomlEnvironment { + pub features: PixiSpanned>, + pub solve_group: Option, +} + +pub(super) enum TomlEnvironmentMapOrSeq { + Map(TomlEnvironment), + Seq(Vec), +} +impl TomlEnvironmentMapOrSeq { + pub fn into_environment(self, name: EnvironmentName) -> Environment { + match self { + TomlEnvironmentMapOrSeq::Map(TomlEnvironment { + features, + solve_group, + }) => Environment { + name, + features: features.value, + features_source_loc: features.span, + solve_group, + }, + TomlEnvironmentMapOrSeq::Seq(features) => Environment { + name, + features, + features_source_loc: None, + solve_group: None, + }, + } + } +} +impl<'de> Deserialize<'de> for TomlEnvironmentMapOrSeq { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + serde_untagged::UntaggedEnumVisitor::new() + .map(|map| map.deserialize().map(TomlEnvironmentMapOrSeq::Map)) + .seq(|seq| seq.deserialize().map(TomlEnvironmentMapOrSeq::Seq)) + .expecting("either a map or a sequence") + .deserialize(deserializer) + } +} diff --git a/src/project/manifest/feature.rs b/src/project/manifest/feature.rs index fd99f81c4..0ee1d06c2 100644 --- a/src/project/manifest/feature.rs +++ b/src/project/manifest/feature.rs @@ -1,9 +1,14 @@ -use super::SystemRequirements; +use super::{Activation, PyPiRequirement, SystemRequirements, Target, TargetSelector}; use crate::project::manifest::target::Targets; +use crate::project::SpecType; +use crate::task::Task; use crate::utils::spanned::PixiSpanned; -use rattler_conda_types::{Channel, Platform}; +use indexmap::IndexMap; +use rattler_conda_types::{Channel, NamelessMatchSpec, PackageName, Platform}; use serde::de::Error; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; +use serde_with::{serde_as, DisplayFromStr, PickFirst}; +use std::collections::HashMap; /// The name of a feature. This is either a string or default for the default feature. #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)] @@ -64,3 +69,72 @@ pub struct Feature { /// Target specific configuration. pub targets: Targets, } + +impl<'de> Deserialize<'de> for Feature { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[serde_as] + #[derive(Deserialize)] + #[serde(deny_unknown_fields, rename_all = "kebab-case")] + struct FeatureInner { + #[serde(default)] + platforms: Option>>, + #[serde_as(deserialize_as = "Option>")] + channels: Option>, + #[serde(default)] + system_requirements: SystemRequirements, + #[serde(default)] + target: IndexMap, Target>, + + #[serde(default)] + #[serde_as(as = "IndexMap<_, PickFirst<(DisplayFromStr, _)>>")] + dependencies: IndexMap, + + #[serde(default)] + #[serde_as(as = "Option>>")] + host_dependencies: Option>, + + #[serde(default)] + #[serde_as(as = "Option>>")] + build_dependencies: Option>, + + #[serde(default)] + pypi_dependencies: Option>, + + /// Additional information to activate an environment. + #[serde(default)] + activation: Option, + + /// Target specific tasks to run in the environment + #[serde(default)] + tasks: HashMap, + } + + let inner = FeatureInner::deserialize(deserializer)?; + + let mut dependencies = HashMap::from_iter([(SpecType::Run, inner.dependencies)]); + if let Some(host_deps) = inner.host_dependencies { + dependencies.insert(SpecType::Host, host_deps); + } + if let Some(build_deps) = inner.build_dependencies { + dependencies.insert(SpecType::Build, build_deps); + } + + let default_target = Target { + dependencies, + pypi_dependencies: inner.pypi_dependencies, + activation: inner.activation, + tasks: inner.tasks, + }; + + Ok(Feature { + name: FeatureName::Default, + platforms: inner.platforms, + channels: inner.channels, + system_requirements: inner.system_requirements, + targets: Targets::from_default_and_user_defined(default_target, inner.target), + }) + } +} diff --git a/src/project/manifest/mod.rs b/src/project/manifest/mod.rs index d8e73d780..56ac45cbc 100644 --- a/src/project/manifest/mod.rs +++ b/src/project/manifest/mod.rs @@ -9,6 +9,7 @@ mod system_requirements; mod target; mod validation; +use crate::project::manifest::environment::TomlEnvironmentMapOrSeq; use crate::{consts, project::SpecType, task::Task, utils::spanned::PixiSpanned}; use ::serde::{Deserialize, Deserializer}; pub use activation::Activation; @@ -22,7 +23,7 @@ pub use python::PyPiRequirement; use rattler_conda_types::{ Channel, ChannelConfig, MatchSpec, NamelessMatchSpec, PackageName, Platform, Version, }; -use serde_with::{serde_as, DisplayFromStr, PickFirst}; +use serde_with::{serde_as, DisplayFromStr, Map, PickFirst}; use std::{ collections::HashMap, path::{Path, PathBuf}, @@ -666,6 +667,15 @@ impl<'de> Deserialize<'de> for ProjectManifest { /// Target specific tasks to run in the environment #[serde(default)] tasks: HashMap, + + /// The features defined in the project. + #[serde(default)] + feature: IndexMap, + + /// The environments the project can create. + #[serde(default)] + #[serde_as(as = "Map<_, _>")] + environments: Vec<(EnvironmentName, TomlEnvironmentMapOrSeq)>, } let toml_manifest = TomlProjectManifest::deserialize(deserializer)?; @@ -708,10 +718,36 @@ impl<'de> Deserialize<'de> for ProjectManifest { solve_group: None, }; + // Construct the features including the default feature + let features: IndexMap = + IndexMap::from_iter([(FeatureName::Default, default_feature)]); + let named_features = toml_manifest + .feature + .into_iter() + .map(|(name, mut feature)| { + feature.name = name.clone(); + (name, feature) + }) + .collect::>(); + let features = features.into_iter().chain(named_features).collect(); + + // Construct the environments including the default environment + let environments: IndexMap = + IndexMap::from_iter([(EnvironmentName::Default, default_environment)]); + let named_environments = toml_manifest + .environments + .into_iter() + .map(|(name, t_env)| { + let env = t_env.into_environment(name.clone()); + (name, env) + }) + .collect::>(); + let environments = environments.into_iter().chain(named_environments).collect(); + Ok(Self { project: toml_manifest.project, - features: IndexMap::from_iter([(FeatureName::Default, default_feature)]), - environments: IndexMap::from_iter([(EnvironmentName::Default, default_environment)]), + features, + environments, }) } } @@ -1237,8 +1273,6 @@ ypackage = {version = ">=1.2.3"} version = "0.1.0" channels = [] platforms = ["linux-64", "win-64"] - - [dependencies] "#; let mut manifest = Manifest::from_str(Path::new(""), file_contents).unwrap(); @@ -1266,8 +1300,6 @@ ypackage = {version = ">=1.2.3"} description = "foo description" channels = [] platforms = ["linux-64", "win-64"] - - [dependencies] "#; let mut manifest = Manifest::from_str(Path::new(""), file_contents).unwrap(); @@ -1309,8 +1341,6 @@ ypackage = {version = ">=1.2.3"} description = "foo description" channels = [] platforms = ["linux-64", "win-64"] - - [dependencies] "#; let mut manifest = Manifest::from_str(Path::new(""), file_contents).unwrap(); @@ -1393,8 +1423,6 @@ ypackage = {version = ">=1.2.3"} description = "foo description" channels = ["conda-forge"] platforms = ["linux-64", "win-64"] - - [dependencies] "#; let mut manifest = Manifest::from_str(Path::new(""), file_contents).unwrap(); @@ -1408,4 +1436,194 @@ ypackage = {version = ">=1.2.3"} assert_eq!(manifest.parsed.project.channels, vec![]); } + + #[test] + fn test_environments_definition() { + let file_contents = r#" + [project] + name = "foo" + version = "0.1.0" + channels = ["conda-forge"] + platforms = ["linux-64", "win-64"] + + [feature.py39.dependencies] + python = "~=3.9.0" + + [feature.py310.dependencies] + python = "~=3.10.0" + + [feature.cuda.dependencies] + cudatoolkit = ">=11.0,<12.0" + + [feature.test.dependencies] + pytest = "*" + + [environments] + default = ["py39"] + cuda = ["cuda", "py310"] + test1 = {features = ["test", "py310"], solve-group = "test"} + test2 = {features = ["py39"], solve-group = "test"} + "#; + let manifest = Manifest::from_str(Path::new(""), file_contents).unwrap(); + let default_env = manifest.default_environment(); + assert_eq!(default_env.name, EnvironmentName::Default); + assert_eq!(default_env.features, vec!["py39"]); + + let cuda_env = manifest + .environment(&EnvironmentName::Named("cuda".to_string())) + .unwrap(); + assert_eq!(cuda_env.features, vec!["cuda", "py310"]); + assert_eq!(cuda_env.solve_group, None); + + let test1_env = manifest + .environment(&EnvironmentName::Named("test1".to_string())) + .unwrap(); + assert_eq!(test1_env.features, vec!["test", "py310"]); + assert_eq!(test1_env.solve_group, Some(String::from("test"))); + + let test2_env = manifest + .environment(&EnvironmentName::Named("test2".to_string())) + .unwrap(); + assert_eq!(test2_env.features, vec!["py39"]); + assert_eq!(test2_env.solve_group, Some(String::from("test"))); + } + + #[test] + fn test_feature_definition() { + let file_contents = r#" + [project] + name = "foo" + channels = [] + platforms = [] + + [feature.cuda] + dependencies = {cuda = "x.y.z", cudnn = "12.0"} + pypi-dependencies = {torch = "~=1.9.0"} + build-dependencies = {cmake = "*"} + platforms = ["linux-64", "osx-arm64"] + activation = {scripts = ["cuda_activation.sh"]} + system-requirements = {cuda = "12"} + channels = ["nvidia", "pytorch"] + tasks = { warmup = "python warmup.py" } + target.osx-arm64 = {dependencies = {mlx = "x.y.z"}} + + "#; + let manifest = Manifest::from_str(Path::new(""), file_contents).unwrap(); + + let cuda_feature = manifest + .parsed + .features + .get(&FeatureName::Named("cuda".to_string())) + .unwrap(); + assert_eq!(cuda_feature.name, FeatureName::Named("cuda".to_string())); + assert_eq!( + cuda_feature + .targets + .default() + .dependencies + .get(&SpecType::Run) + .unwrap() + .get(&PackageName::from_str("cuda").unwrap()) + .unwrap() + .to_string(), + "==x.y.z" + ); + assert_eq!( + cuda_feature + .targets + .default() + .dependencies + .get(&SpecType::Run) + .unwrap() + .get(&PackageName::from_str("cudnn").unwrap()) + .unwrap() + .to_string(), + "==12.0" + ); + assert_eq!( + cuda_feature + .targets + .default() + .pypi_dependencies + .as_ref() + .unwrap() + .get( + &rip::types::PackageName::from_str("torch") + .expect("torch should be a valid name") + ) + .expect("pypi requirement should be available") + .version + .clone() + .unwrap() + .to_string(), + "~=1.9.0" + ); + assert_eq!( + cuda_feature + .targets + .default() + .dependencies + .get(&SpecType::Build) + .unwrap() + .get(&PackageName::from_str("cmake").unwrap()) + .unwrap() + .to_string(), + "*" + ); + assert_eq!( + cuda_feature + .targets + .default() + .activation + .as_ref() + .unwrap() + .scripts + .as_ref() + .unwrap(), + &vec![String::from("cuda_activation.sh")] + ); + assert_eq!( + cuda_feature + .system_requirements + .cuda + .as_ref() + .unwrap() + .to_string(), + "12" + ); + assert_eq!( + cuda_feature + .channels + .as_ref() + .unwrap() + .iter() + .map(|c| c.name.clone().unwrap()) + .collect::>(), + vec!["nvidia", "pytorch"] + ); + assert_eq!( + cuda_feature + .targets + .for_target(&TargetSelector::Platform(Platform::OsxArm64)) + .unwrap() + .dependencies + .get(&SpecType::Run) + .unwrap() + .get(&PackageName::from_str("mlx").unwrap()) + .unwrap() + .to_string(), + "==x.y.z" + ); + assert_eq!( + cuda_feature + .targets + .default() + .tasks + .get("warmup") + .unwrap() + .as_single_command() + .unwrap(), + "python warmup.py" + ); + } } diff --git a/src/project/manifest/snapshots/pixi__project__manifest__test__invalid_key.snap b/src/project/manifest/snapshots/pixi__project__manifest__test__invalid_key.snap index 1f49cd965..3de48c4ca 100644 --- a/src/project/manifest/snapshots/pixi__project__manifest__test__invalid_key.snap +++ b/src/project/manifest/snapshots/pixi__project__manifest__test__invalid_key.snap @@ -6,7 +6,7 @@ TOML parse error at line 8, column 2 | 8 | [foobar] | ^^^^^^ -unknown field `foobar`, expected one of `project`, `system-requirements`, `target`, `dependencies`, `host-dependencies`, `build-dependencies`, `pypi-dependencies`, `activation`, `tasks` +unknown field `foobar`, expected one of `project`, `system-requirements`, `target`, `dependencies`, `host-dependencies`, `build-dependencies`, `pypi-dependencies`, `activation`, `tasks`, `feature`, `environments` TOML parse error at line 8, column 16 |