Skip to content

Commit c6f35fd

Browse files
committed
Implement get_disjoint_mut for arrays of keys
1 parent b56f035 commit c6f35fd

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

src/map.rs

+69
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,37 @@ where
790790
}
791791
}
792792

793+
/// Return the values for `N` keys. If any key is missing a value, or there
794+
/// are duplicate keys, `None` is returned.
795+
///
796+
/// # Examples
797+
///
798+
/// ```
799+
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
800+
/// assert_eq!(map.get_disjoint_mut([&2, &1]), Some([&mut 'c', &mut 'a']));
801+
/// ```
802+
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> Option<[&mut V; N]>
803+
where
804+
Q: Hash + Equivalent<K> + ?Sized,
805+
{
806+
let len = self.len();
807+
let indices = keys.map(|key| self.get_index_of(key));
808+
809+
// Handle out-of-bounds indices with panic as this is an internal error in get_index_of.
810+
for idx in indices {
811+
let idx = idx?;
812+
debug_assert!(
813+
idx < len,
814+
"Index is out of range! Got '{}' but length is '{}'",
815+
idx,
816+
len
817+
);
818+
}
819+
let indices = indices.map(Option::unwrap);
820+
let entries = self.get_disjoint_indices_mut(indices)?;
821+
Some(entries.map(|(_key, value)| value))
822+
}
823+
793824
/// Remove the key-value pair equivalent to `key` and return
794825
/// its value.
795826
///
@@ -1196,6 +1227,44 @@ impl<K, V, S> IndexMap<K, V, S> {
11961227
Some(IndexedEntry::new(&mut self.core, index))
11971228
}
11981229

1230+
/// Get an array of `N` key-value pairs by `N` indices
1231+
///
1232+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
1233+
///
1234+
/// Computes in **O(1)** time.
1235+
///
1236+
/// # Examples
1237+
///
1238+
/// ```
1239+
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
1240+
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Some([(&2, &mut 'c'), (&1, &mut 'a')]));
1241+
/// ```
1242+
pub fn get_disjoint_indices_mut<const N: usize>(
1243+
&mut self,
1244+
indices: [usize; N],
1245+
) -> Option<[(&K, &mut V); N]> {
1246+
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
1247+
let len = self.len();
1248+
for i in 0..N {
1249+
let idx = indices[i];
1250+
if idx >= len || indices[i + 1..N].contains(&idx) {
1251+
return None;
1252+
}
1253+
}
1254+
1255+
let entries_ptr = self.as_entries_mut().as_mut_ptr();
1256+
let out = indices.map(|i| {
1257+
// SAFETY: The base pointer is valid as it comes from a slice and the deref is always
1258+
// in-bounds as we've already checked the indices above.
1259+
#[allow(unsafe_code)]
1260+
unsafe {
1261+
(*(entries_ptr.add(i))).ref_mut()
1262+
}
1263+
});
1264+
1265+
Some(out)
1266+
}
1267+
11991268
/// Returns a slice of key-value pairs in the given range of indices.
12001269
///
12011270
/// Valid indices are `0 <= index < self.len()`.

src/map/tests.rs

+91
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,94 @@ move_index_oob!(test_move_index_out_of_bounds_0_10, 0, 10);
828828
move_index_oob!(test_move_index_out_of_bounds_0_max, 0, usize::MAX);
829829
move_index_oob!(test_move_index_out_of_bounds_10_0, 10, 0);
830830
move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0);
831+
832+
#[test]
833+
fn disjoint_mut_empty_map() {
834+
let mut map: IndexMap<u32, u32> = IndexMap::default();
835+
assert!(map.get_disjoint_mut([&0, &1, &2, &3]).is_none());
836+
}
837+
838+
#[test]
839+
fn disjoint_mut_empty_param() {
840+
let mut map: IndexMap<u32, u32> = IndexMap::default();
841+
map.insert(1, 10);
842+
assert!(map.get_disjoint_mut([] as [&u32; 0]).is_some());
843+
}
844+
845+
#[test]
846+
fn disjoint_mut_single_fail() {
847+
let mut map: IndexMap<u32, u32> = IndexMap::default();
848+
map.insert(1, 10);
849+
assert!(map.get_disjoint_mut([&0]).is_none());
850+
}
851+
852+
#[test]
853+
fn disjoint_mut_single_success() {
854+
let mut map: IndexMap<u32, u32> = IndexMap::default();
855+
map.insert(1, 10);
856+
assert_eq!(map.get_disjoint_mut([&1]), Some([&mut 10]));
857+
}
858+
859+
#[test]
860+
fn disjoint_mut_multi_success() {
861+
let mut map: IndexMap<u32, u32> = IndexMap::default();
862+
map.insert(1, 100);
863+
map.insert(2, 200);
864+
map.insert(3, 300);
865+
map.insert(4, 400);
866+
assert_eq!(map.get_disjoint_mut([&1, &2]), Some([&mut 100, &mut 200]));
867+
assert_eq!(map.get_disjoint_mut([&1, &3]), Some([&mut 100, &mut 300]));
868+
assert_eq!(
869+
map.get_disjoint_mut([&3, &1, &4, &2]),
870+
Some([&mut 300, &mut 100, &mut 400, &mut 200])
871+
);
872+
}
873+
874+
#[test]
875+
fn disjoint_mut_multi_success_unsized_key() {
876+
let mut map: IndexMap<&'static str, u32> = IndexMap::default();
877+
map.insert("1", 100);
878+
map.insert("2", 200);
879+
map.insert("3", 300);
880+
map.insert("4", 400);
881+
assert_eq!(map.get_disjoint_mut(["1", "2"]), Some([&mut 100, &mut 200]));
882+
assert_eq!(map.get_disjoint_mut(["1", "3"]), Some([&mut 100, &mut 300]));
883+
assert_eq!(
884+
map.get_disjoint_mut(["3", "1", "4", "2"]),
885+
Some([&mut 300, &mut 100, &mut 400, &mut 200])
886+
);
887+
}
888+
889+
#[test]
890+
fn disjoint_mut_multi_fail_missing() {
891+
let mut map: IndexMap<u32, u32> = IndexMap::default();
892+
map.insert(1, 10);
893+
map.insert(1123, 100);
894+
map.insert(321, 20);
895+
map.insert(1337, 30);
896+
assert_eq!(map.get_disjoint_mut([&121, &1123]), None);
897+
assert_eq!(map.get_disjoint_mut([&1, &1337, &56]), None);
898+
assert_eq!(map.get_disjoint_mut([&1337, &123, &321, &1, &1123]), None);
899+
}
900+
901+
#[test]
902+
fn disjoint_mut_multi_fail_duplicate() {
903+
let mut map: IndexMap<u32, u32> = IndexMap::default();
904+
map.insert(1, 10);
905+
map.insert(1123, 100);
906+
map.insert(321, 20);
907+
map.insert(1337, 30);
908+
assert_eq!(map.get_disjoint_mut([&1, &1]), None);
909+
assert_eq!(
910+
map.get_disjoint_mut([&1337, &123, &321, &1337, &1, &1123]),
911+
None
912+
);
913+
}
914+
915+
#[test]
916+
fn many_index_mut_fail_oob() {
917+
let mut map: IndexMap<u32, u32> = IndexMap::default();
918+
map.insert(1, 10);
919+
map.insert(321, 20);
920+
assert_eq!(map.get_disjoint_indices_mut([1, 3]), None);
921+
}

0 commit comments

Comments
 (0)