diff --git a/mshv-bindings/src/regs.rs b/mshv-bindings/src/regs.rs index fad84ad1..d0f76842 100644 --- a/mshv-bindings/src/regs.rs +++ b/mshv-bindings/src/regs.rs @@ -475,6 +475,13 @@ impl Buffer { pub fn size(&self) -> usize { self.layout.size() } + + pub fn zero_out_buf(&mut self) { + // SAFETY: We write zeros to a valid pointer and the size is valid and allocated from a valid layout. + unsafe { + ::std::ptr::write_bytes(self.buf, 0u8, self.size()); + } + } } impl Drop for Buffer { @@ -673,3 +680,122 @@ pub struct SuspendRegisters { pub struct MiscRegs { pub hypercall: u64, } + +// Total size: 13512 bytes +// 1. MSHV_VP_STATE_LAPIC, Size: 1024 bytes; +// 2. MSHV_VP_STATE_XSAVE, Size: 4096 bytes; +// 3. MSHV_VP_STATE_SIMP, Size: 4096 bytes; +// 4. MSHV_VP_STATE_SIEFP, Size: 4096 bytes; +// 5. MSHV_VP_STATE_SYNTHETIC_TIMERS, Size: 200 bytes; +const VP_STATE_COMP_SIZES: [usize; 5] = [0x400, 0x1000, 0x1000, 0x1000, 0xC8]; + +pub const VP_STATE_COMPONENTS_BUFFER_SIZE: usize = VP_STATE_COMP_SIZES + [MSHV_VP_STATE_LAPIC as usize] + + VP_STATE_COMP_SIZES[MSHV_VP_STATE_XSAVE as usize] + + VP_STATE_COMP_SIZES[MSHV_VP_STATE_SIMP as usize] + + VP_STATE_COMP_SIZES[MSHV_VP_STATE_SIEFP as usize] + + VP_STATE_COMP_SIZES[MSHV_VP_STATE_SYNTHETIC_TIMERS as usize]; + +#[inline(always)] +fn get_vp_state_comp_start_offset(index: usize) -> usize { + VP_STATE_COMP_SIZES[0..index].iter().copied().sum() +} + +// Total five components are stored in a single buffer serially +// Components are: +// Local APIC, Xsave, Synthetic Message Page, Synthetic Event Flags Page +// and Synthetic Timers. +#[repr(C)] +#[derive(Copy, Clone, Debug, AsBytes, FromBytes, FromZeroes)] +/// Fixed buffer for VP state components +pub struct AllVpStateComponents { + pub buffer: [u8; VP_STATE_COMPONENTS_BUFFER_SIZE], +} + +impl Default for AllVpStateComponents { + fn default() -> Self { + unsafe { ::std::mem::zeroed() } + } +} + +impl AllVpStateComponents { + pub fn copy_to_or_from_buffer(&mut self, index: usize, buffer: &mut Buffer, to_buffer: bool) { + let len: usize = VP_STATE_COMP_SIZES[index]; + + if len > buffer.size() { + panic!("Invalid buffer length for state components"); + } + + let start = get_vp_state_comp_start_offset(index); + let end = start + len; + + if to_buffer { + // SAFETY: buffer is large enough to hold state data + unsafe { ptr::copy(self.buffer[start..end].as_ptr(), buffer.buf, len) }; + } else { + // SAFETY: buffer is large enough to hold state data + unsafe { ptr::copy(buffer.buf, self.buffer[start..end].as_mut_ptr(), len) }; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::slice::from_raw_parts_mut; + + #[test] + fn test_all_vp_state_components_copy_to_buffer() { + let mut states: AllVpStateComponents = AllVpStateComponents::default(); + let mut buffer = Buffer::new(HV_PAGE_SIZE, HV_PAGE_SIZE).unwrap(); + + for i in 0..VP_STATE_COMPONENTS_BUFFER_SIZE { + states.buffer[i] = 0xB9; + } + + //test copy to buffer + for i in 0..MSHV_VP_STATE_COUNT { + let len = VP_STATE_COMP_SIZES[i as usize]; + let start = get_vp_state_comp_start_offset(i as usize); + let end = start + len; + states.copy_to_or_from_buffer(i as usize, &mut buffer, true); + // SAFETY: We read less than or equal to buffer length and the slice is valid. + let buf_arr = unsafe { std::slice::from_raw_parts(buffer.buf, len) }; + assert!(states.buffer[start..end] + .iter() + .zip(buf_arr) + .all(|(a, b)| a == b)); + } + } + + #[test] + fn test_all_vp_state_components_copy_from_buffer() { + let mut states: AllVpStateComponents = AllVpStateComponents::default(); + let buffer = Buffer::new(HV_PAGE_SIZE, HV_PAGE_SIZE).unwrap(); + let mut copy_buffer = Buffer::new(HV_PAGE_SIZE, HV_PAGE_SIZE).unwrap(); + + // SAFETY: Safe because the entire buffer is accessible as bytes, + // modifying them in the form of a byte slice is valid + let mut_buf = unsafe { from_raw_parts_mut(buffer.buf, buffer.layout.size()) }; + for itm in mut_buf.iter_mut().take(HV_PAGE_SIZE) { + *itm = 0xA5; + } + + // SAFETY: buffer is large enough to hold state data + unsafe { ptr::copy(mut_buf.as_mut_ptr(), copy_buffer.buf, HV_PAGE_SIZE) }; + + //test copy to buffer + for i in 0..MSHV_VP_STATE_COUNT { + let len = VP_STATE_COMP_SIZES[i as usize]; + let start = get_vp_state_comp_start_offset(i as usize); + let end = start + len; + + states.copy_to_or_from_buffer(i as usize, &mut copy_buffer, false); + let buf_arr = &mut_buf[0..len]; + assert!(states.buffer[start..end] + .iter() + .zip(buf_arr) + .all(|(a, b)| a == b)); + } + } +} diff --git a/mshv-bindings/src/serializers.rs b/mshv-bindings/src/serializers.rs index fe273b33..4979e328 100644 --- a/mshv-bindings/src/serializers.rs +++ b/mshv-bindings/src/serializers.rs @@ -53,6 +53,30 @@ impl Serialize for XSave { data_buffer.serialize(serializer) } } + +impl Serialize for AllVpStateComponents { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let data_buffer = &self.buffer[..]; + data_buffer.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for AllVpStateComponents { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + let data_buffer: Vec = Vec::deserialize(deserializer)?; + let mut val = AllVpStateComponents::default(); + // This panics if the source and destination have different lengths. + val.buffer.copy_from_slice(&data_buffer[..]); + Ok(val) + } +} + #[cfg(test)] mod tests { use super::*; @@ -88,4 +112,21 @@ mod tests { .zip(d_xsave.buffer.iter()) .all(|(a, b)| a == b)); } + + #[test] + fn test_vp_state_components_serialization_deserialization() { + let mut states = AllVpStateComponents { + ..Default::default() + }; + for i in 0..VP_STATE_COMPONENTS_BUFFER_SIZE { + states.buffer[i] = 0xC8; + } + let serialized = serde_json::to_string(&states).expect("err ser"); + let d_states: AllVpStateComponents = serde_json::from_str(&serialized).expect("err unser"); + assert!(states + .buffer + .iter() + .zip(d_states.buffer.iter()) + .all(|(a, b)| a == b)); + } } diff --git a/mshv-ioctls/src/ioctls/vcpu.rs b/mshv-ioctls/src/ioctls/vcpu.rs index ed69768c..114b9754 100644 --- a/mshv-ioctls/src/ioctls/vcpu.rs +++ b/mshv-ioctls/src/ioctls/vcpu.rs @@ -1076,6 +1076,45 @@ impl VcpuFd { self.set_reg(®_assocs)?; Ok(()) } + + /// Gets the VP state components + #[cfg(not(target_arch = "aarch64"))] + pub fn get_all_vp_state_components(&self) -> Result { + let mut states: AllVpStateComponents = AllVpStateComponents::default(); + let mut buffer = Buffer::new(HV_PAGE_SIZE, HV_PAGE_SIZE)?; + + for i in 0..MSHV_VP_STATE_COUNT { + buffer.zero_out_buf(); + let mut vp_state = mshv_get_set_vp_state { + buf_ptr: buffer.buf as u64, + buf_sz: buffer.size() as u32, + type_: i as u8, + ..Default::default() + }; + self.get_vp_state_ioctl(&mut vp_state)?; + states.copy_to_or_from_buffer(i as usize, &mut buffer, false); + } + Ok(states) + } + + /// Sets the VP state components + #[cfg(not(target_arch = "aarch64"))] + pub fn set_all_vp_state_components(&self, states: &mut AllVpStateComponents) -> Result<()> { + let mut buffer = Buffer::new(HV_PAGE_SIZE, HV_PAGE_SIZE)?; + + for i in 0..MSHV_VP_STATE_COUNT { + buffer.zero_out_buf(); + states.copy_to_or_from_buffer(i as usize, &mut buffer, true); + let vp_state = mshv_get_set_vp_state { + type_: i as u8, + buf_sz: buffer.size() as u32, + buf_ptr: buffer.buf as u64, + ..Default::default() + }; + self.set_vp_state_ioctl(&vp_state)?; + } + Ok(()) + } } #[allow(dead_code)] @@ -1495,4 +1534,19 @@ mod tests { let max_function = res[0]; assert!(max_function >= 1); } + + #[test] + fn test_get_set_vp_state_components() { + let hv = Mshv::new().unwrap(); + let vm = hv.create_vm().unwrap(); + let vcpu = vm.create_vcpu(0).unwrap(); + let mut states = vcpu.get_all_vp_state_components().unwrap(); + vcpu.set_all_vp_state_components(&mut states).unwrap(); + let ret_states = vcpu.get_all_vp_state_components().unwrap(); + assert!(states + .buffer + .iter() + .zip(ret_states.buffer) + .all(|(a, b)| *a == b)); + } }