Skip to content

Commit

Permalink
port semantic pick
Browse files Browse the repository at this point in the history
  • Loading branch information
jimmytyyang committed Apr 30, 2024
1 parent a3eeee9 commit d3f0ba6
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 1 deletion.
8 changes: 8 additions & 0 deletions spot_rl_experiments/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ WEIGHTS_TORCHSCRIPT:
# Mobile Gaze torchscript module files path
# MOBILE_GAZE: "weights/torchscript/mg97_2_latest_combined_net.torchscript"
MOBILE_GAZE: "weights/mobile_gaze_v2/mg97hv1_4_ckpt.16.torchscript"
SEMANTIC_GAZE: "weights/mobile_gaze_v2/mg97h103_15_ckpt.35.torchscript"

# Static place
PLACE: "weights/torchscript/place_10deg_32_seed300_1649709235_ckpt.75_combined_net.torchscript"
Expand All @@ -30,6 +31,7 @@ WEIGHTS:
# Mobile Gaze torchscript module files path
# MOBILE_GAZE: "weights/torchscript/mg97_2_latest_combined_net.torchscript"
MOBILE_GAZE: "weights/mobile_gaze_v2/mg97hv1_4_ckpt.16.torchscript"
SEMANTIC_GAZE: "weights/mobile_gaze_v2/mg97h103_15_ckpt.35.torchscript"

# Static place
PLACE: "weights/final_paper/place_10deg_32_seed300_1649709235_ckpt.75.pth"
Expand Down Expand Up @@ -78,6 +80,12 @@ MAX_JOINT_MOVEMENT_MOBILE_GAZE: 0.04 # radian
MOBILE_GAZE_ACTION_SPACE_LENGTH: 7
HEURISTIC_SEARCH_ANGLE_INTERVAL: 20

# Semantic Gaze env
MAX_LIN_DIST_SEMANTIC_GAZE: 0
MAX_ANG_DIST_SEMANTIC_GAZE: 0 # degrees
MAX_JOINT_MOVEMENT_SEMANTIC_GAZE: 0.04 # radian
SEMANTIC_GAZE_ACTION_SPACE_LENGTH: 7

# Place env
EE_GRIPPER_OFFSET: [0.2, 0.0, 0.05]
SUCC_XY_DIST: 0.25
Expand Down
20 changes: 20 additions & 0 deletions spot_rl_experiments/experiments/skill_test/test_semantic_gaze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import numpy as np
from spot_rl.envs.skill_manager import SpotSkillManager

if __name__ == "__main__":
    from perception_and_utils.utils.generic_utils import map_user_input_to_boolean

    # Bring up the skill manager with mobile pick enabled
    spotskillmanager = SpotSkillManager(use_mobile_pick=True)

    # Pick the object, release it, then let the operator decide whether to repeat
    while True:
        spotskillmanager.semanticpick("penguin", "topdown")
        spotskillmanager.spot.open_gripper()
        if not map_user_input_to_boolean("Do you want to do it again ? Y/N "):
            break

    # Navigate to dock and shutdown
    spotskillmanager.dock()
106 changes: 106 additions & 0 deletions spot_rl_experiments/spot_rl/envs/gaze_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,109 @@ def get_observations(self):

def get_success(self, observations):
return self.grasp_attempted


class SpotSemanticGazeEnv(SpotBaseEnv):
    """
    Gaze environment whose observations additionally carry a scalar
    ("topdown_or_side_grasping") derived from the requested grasping type.

    The grasping type is supplied on every reset() and must be one of
    SUPPORTED_GRASPING_TYPES.
    """

    # Grasping types for which get_observations() knows how to compute delta
    SUPPORTED_GRASPING_TYPES = ("topdown", "side")

    def __init__(self, config, spot: Spot):
        """
        Args:
            config: Config object; the *_SEMANTIC_GAZE limits and
                GAZE_ARM_JOINT_ANGLES (in degrees) are read from it
            spot (Spot): Spot object
        """
        super().__init__(
            config,
            spot,
            stopwatch=None,
            # Select the semantic-gaze specific limits from the config
            max_joint_movement_key="MAX_JOINT_MOVEMENT_SEMANTIC_GAZE",
            max_lin_dist_key="MAX_LIN_DIST_SEMANTIC_GAZE",
            max_ang_dist_key="MAX_ANG_DIST_SEMANTIC_GAZE",
        )
        self.target_obj_name = None
        self.initial_arm_joint_angles = np.deg2rad(config.GAZE_ARM_JOINT_ANGLES)
        self.grasping_type = "topdown"

    def reset(self, target_obj_name, grasping_type, *args, **kwargs):
        """
        Move the arm to the initial gaze pose, open the gripper and reset the env.

        Args:
            target_obj_name: Name of the object to gaze at / pick
            grasping_type: One of SUPPORTED_GRASPING_TYPES
            *args, **kwargs: Forwarded to the parent reset()
        Returns:
            The observations returned by the parent reset()
        Raises:
            ValueError: if grasping_type is not supported
        """
        # Fail fast, before any command is sent to the robot
        if grasping_type not in self.SUPPORTED_GRASPING_TYPES:
            raise ValueError(f"Do not support {grasping_type} grasping")

        # Move arm to initial configuration
        cmd_id = self.spot.set_arm_joint_positions(
            positions=self.initial_arm_joint_angles, travel_time=1
        )

        # Block until arm arrives with incremental timeout for 3 attempts
        timeout_sec = 1.0
        max_allowed_timeout_sec = 3.0
        status = False
        while status is False and timeout_sec <= max_allowed_timeout_sec:
            status = self.spot.block_until_arm_arrives(cmd_id, timeout_sec=timeout_sec)
            timeout_sec += 1.0

        print("Open gripper called in Gaze")
        self.spot.open_gripper()

        # Store the grasping type BEFORE delegating to the parent reset:
        # get_observations() reads self.grasping_type, and the parent reset
        # presumably builds the returned observations through it (TODO confirm).
        # Assigning it afterwards would encode the previous episode's type.
        self.grasping_type = grasping_type

        # Update target object name as provided in config
        observations = super().reset(*args, target_obj_name=target_obj_name, **kwargs)
        rospy.set_param("object_target", target_obj_name)
        rospy.set_param("is_gripper_blocked", 0)
        return observations

    def step(self, action_dict: Dict[str, Any]):
        """
        Add the grasp/place flags to action_dict and forward it to the parent step().

        Args:
            action_dict (Dict[str, Any]): Action dict; mutated in place
        Returns:
            The (observations, reward, done, info) tuple of the parent step()
        """
        grasp = self.should_grasp()

        # Update the action_dict with grasp and place flags
        action_dict["grasp"] = grasp
        action_dict["place"] = False  # TODO: Why is gaze getting flag for place?

        observations, reward, done, info = super().step(
            action_dict=action_dict,
        )
        return observations, reward, done, info

    def remap_observation_keys_for_hab3(self, observations):
        """
        Change observation keys as per hab3.
        @INFO: Policies trained on older hab versions DON'T need remapping

        Args:
            observations: Dict holding "arm_depth_bbox", "arm_depth" and "joint"
        Returns:
            New dict with the same values under the hab3 key names
        """
        return {
            "arm_depth_bbox_sensor": observations["arm_depth_bbox"],
            "articulated_agent_arm_depth": observations["arm_depth"],
            "joint": observations["joint"],
        }

    def get_observations(self):
        """
        Build the policy observations: gripper depth images, arm joints and the
        grasping-type scalar "topdown_or_side_grasping".

        Raises:
            ValueError: if self.grasping_type is not supported
        """
        arm_depth, arm_depth_bbox = self.get_gripper_images()
        observations = {
            "joint": self.get_arm_joints(),
            "arm_depth": arm_depth,
            "arm_depth_bbox": arm_depth_bbox,
        }
        # Remap observation keys for mobile gaze as it was trained with Habitat version3
        observations = self.remap_observation_keys_for_hab3(observations)

        # Get the observation for top down or side grasping
        # Transformation between the "vision" and "hand" frames
        ee_T = self.spot.get_magnum_Matrix4_spot_a_T_b("vision", "hand")
        # Transformation between the "vision" and "body" frames
        base_T = self.spot.get_magnum_Matrix4_spot_a_T_b("vision", "body")
        # Compose both so that the hand is expressed relative to the body
        base_to_ee_T = base_T.inverted() @ ee_T
        target_vector = np.array([0, 0, 1.0])
        # Direction of the hand's z axis after applying base_to_ee_T
        dir_vector = np.array(base_to_ee_T.transform_vector(target_vector))

        if self.grasping_type == "topdown":
            delta = 1.0 - abs(dir_vector[0])
        elif self.grasping_type == "side":
            delta = abs(dir_vector[0])
        else:
            # Without this branch delta would be unbound below
            raise ValueError(f"Do not support {self.grasping_type} grasping")
        print(f"delta {delta} {self.grasping_type} {dir_vector}")
        observations["topdown_or_side_grasping"] = np.array(
            [delta],
            dtype=np.float32,
        )
        return observations

    def get_success(self, observations):
        """Report success as soon as a grasp has been attempted."""
        return self.grasp_attempted
33 changes: 33 additions & 0 deletions spot_rl_experiments/spot_rl/envs/skill_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
OpenCloseDrawer,
Pick,
Place,
SemanticPick,
SemanticPlace,
)
from spot_rl.utils.construct_configs import (
Expand Down Expand Up @@ -215,6 +216,10 @@ def __initiate_controllers(self, use_policies: bool = True):
spot=self.spot,
config=self.open_close_drawer_config,
)
self.semantic_gaze_controller = SemanticPick(
spot=self.spot,
config=self.pick_config,
)

def reset(self):
# Reset the policies and environments via the controllers
Expand Down Expand Up @@ -340,6 +345,34 @@ def pick(self, target_obj_name: str = None) -> Tuple[bool, str]:
conditional_print(message=message, verbose=self.verbose)
return status, message

def semanticpick(
    self, target_obj_name: str = None, grasping_type: str = "topdown"
) -> Tuple[bool, str]:
    """
    Perform the semantic pick action on the pick target specified as string
    Args:
        target_obj_name (str): Descriptive name of the pick target (eg: ball_plush)
        grasping_type (str): The grasping type, either "topdown" or "side"
    Returns:
        bool: True if pick was successful, False otherwise
        str: Message indicating the status of the pick
    Raises:
        ValueError: If grasping_type is neither "topdown" nor "side"
    """
    # Explicit raise instead of assert: asserts are stripped when Python runs
    # with -O, which would let an unsupported grasping type reach the robot
    if grasping_type not in ("topdown", "side"):
        raise ValueError(f"Do not support {grasping_type} grasping")

    goal_dict = {
        "target_object": target_obj_name,
        "take_user_input": False,
        "grasping_type": grasping_type,
    }  # type: Dict[str, Any]
    status, message = self.semantic_gaze_controller.execute(goal_dict=goal_dict)
    conditional_print(message=message, verbose=self.verbose)
    return status, message

@multimethod # type: ignore
def place(self, place_target: str = None, ee_orientation_at_grasping: np.ndarray = None, is_local: bool = False, visualize: bool = False) -> Tuple[bool, str]: # type: ignore
"""
Expand Down
35 changes: 35 additions & 0 deletions spot_rl_experiments/spot_rl/real_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,41 @@ def __init__(self, checkpoint_path, device, config: CN = CN()):
)


class SemanticGazePolicy(RealPolicy):
    """
    Policy wrapper for semantic gaze: declares the observation and action
    spaces and hands them to RealPolicy together with the checkpoint.
    """

    def __init__(self, checkpoint_path, device, config: CN = CN()):
        """
        Args:
            checkpoint_path: Path of the policy weights to load
            device: Device passed through to RealPolicy
            config (CN): Config; SEMANTIC_GAZE_ACTION_SPACE_LENGTH is read from it
        """
        float32_min = np.finfo(np.float32).min
        float32_max = np.finfo(np.float32).max

        observation_space = SpaceDict(
            {
                "arm_depth_bbox_sensor": spaces.Box(
                    low=float32_min,
                    high=float32_max,
                    shape=(240, 228, 1),
                    dtype=np.float32,
                ),
                "articulated_agent_arm_depth": spaces.Box(
                    low=0.0, high=1.0, shape=(240, 228, 1), dtype=np.float32
                ),
                "topdown_or_side_grasping": spaces.Box(
                    low=float32_min,
                    high=float32_max,
                    shape=(1,),
                    dtype=np.float32,
                ),
                "joint": spaces.Box(
                    low=float32_min,
                    high=float32_max,
                    shape=(4,),
                    dtype=np.float32,
                ),
            }
        )

        # Prefer the semantic-gaze specific key; fall back to the mobile gaze
        # key (and then 7) so configs without the new key behave as before
        action_space_length = config.get(
            "SEMANTIC_GAZE_ACTION_SPACE_LENGTH",
            config.get("MOBILE_GAZE_ACTION_SPACE_LENGTH", 7),
        )
        action_space = spaces.Box(-1.0, 1.0, (action_space_length,))
        super().__init__(
            checkpoint_path, observation_space, action_space, device, config=config
        )


class PlacePolicy(RealPolicy):
def __init__(self, checkpoint_path, device, config: CN = CN()):
observation_space = SpaceDict(
Expand Down
102 changes: 101 additions & 1 deletion spot_rl_experiments/spot_rl/skills/atomic_skills.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
conditional_print,
map_user_input_to_boolean,
)
from spot_rl.envs.gaze_env import SpotGazeEnv
from spot_rl.envs.gaze_env import SpotGazeEnv, SpotSemanticGazeEnv

# Import Envs
from spot_rl.envs.nav_env import SpotNavEnv
Expand All @@ -26,6 +26,7 @@
NavPolicy,
OpenCloseDrawerPolicy,
PlacePolicy,
SemanticGazePolicy,
SemanticPlacePolicy,
)

Expand Down Expand Up @@ -542,6 +543,105 @@ def split_action(self, action: np.ndarray) -> Dict[str, Any]:
return action_dict


class SemanticPick(Pick):
    """
    Semantic Pick is used to gaze at, and pick given objects with a requested grasping type.
    CAUTION: The robot will drop the object after picking it, please use objects that are not fragile
    Expected goal_dict input:
        goal_dict = {
            "target_object": "apple",     # (Necessary) Name of the target object to pick
            "take_user_input": False,     # (Optional) Whether to take user input for verifying success of the gaze
            "grasping_type": "topdown",   # (Optional) Grasping type, defaults to "topdown"
        }
    Args:
        spot (Spot): Spot object
        config (Config): Config object
    How to use:
        1. Create a SemanticPick object
        2. Call execute(goal_dict) method with "target_object" as a str in input goal_dict
        3. Call get_most_recent_result_log() to get the result from the most recent pick operation
    Example:
        config = construct_config_for_gaze(opts=[])
        spot = Spot("spot_client_name")
        with spot.get_lease(hijack=True):
            spot.power_robot()
            gaze_target_list = ["apple", "banana"]
            results = []
            semantic_pick = SemanticPick(spot, config)
            for target_object in gaze_target_list:
                goal_dict = {"target_object": target_object, "grasping_type": "topdown"}
                status, feedback = semantic_pick.execute(goal_dict=goal_dict)
                results.append(semantic_pick.get_most_recent_result_log())
            spot.shutdown(should_dock=True)
    """

    def __init__(self, spot, config=None) -> None:
        if not config:
            config = construct_config_for_gaze()
        super().__init__(spot, config)

        # Replace the policy and env set up by Pick with the semantic gaze ones
        self.policy = SemanticGazePolicy(
            self.config.WEIGHTS.SEMANTIC_GAZE,
            device=self.config.DEVICE,
            config=self.config,
        )
        self.policy.reset()

        self.env = SpotSemanticGazeEnv(self.config, spot)

    def reset_skill(self, goal_dict: Dict[str, Any]) -> Any:
        """Refer to class Skill for documentation"""
        # Any exception raised by the sanity check propagates to the caller
        self.sanity_check(goal_dict)

        target_obj_name = goal_dict.get("target_object", None)
        take_user_input = goal_dict.get("take_user_input", False)  # type: bool
        grasping_type = goal_dict.get("grasping_type", "topdown")
        conditional_print(
            message=f"Gaze at object : {target_obj_name} - {'WILL' if take_user_input else 'WILL NOT'} take user input at the end for verification of pick",
            verbose=self.verbose,
        )

        # Reset the env and policy
        observations = self.env.reset(
            target_obj_name=target_obj_name, grasping_type=grasping_type
        )
        self.policy.reset()

        # Logging and Debug
        self.env.say(
            f"Gaze at target object - {target_obj_name} with {grasping_type} grasping"
        )
        print(
            "The robot will drop the object after picking it, please use objects that are not fragile"
        )

        # Reset logged data at init
        self.reset_logger()

        return observations

    def split_action(self, action: np.ndarray) -> Dict[str, Any]:
        """Refer to class Skill for documentation"""
        # first 4 are arm actions, then 2 are base actions & last bit is unused
        return {
            "arm_action": action[0:4],
            "base_action": action[4:6],
        }


class Place(Skill):
"""
Place controller is used to execute place for given place targets
Expand Down
16 changes: 16 additions & 0 deletions spot_rl_experiments/utils/cparamsssemanticmobilegaze.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Parameters for the semantic mobile gaze policy
# NOTE(review): presumably consumed by a script that loads TARGET_HAB3_POLICY_PATH
# and writes OUTPUT_COMBINED_NET_SAVE_PATH - confirm against the consumer.
# Fully qualified class name of the policy model
MODEL_CLASS_NAME : "habitat_baselines.rl.ddppo.policy.resnet_policy.PointNavResNetPolicy"
# Observation spaces of the policy, one entry per observation key.
# Each value looks like [shape, low, high, dtype] - TODO confirm ordering with the consumer
OBSERVATIONS_DICT:
arm_depth_bbox_sensor: [[240, 228, 1], 'np.finfo(np.float32).min', 'np.finfo(np.float32).max', 'np.float32']
articulated_agent_arm_depth: [[240, 228, 1], '0.0', '1.0', 'np.float32']
topdown_or_side_grasping: [[1,], 'np.finfo(np.float32).min', 'np.finfo(np.float32).max', 'np.float32']
joint: [[4,], 'np.finfo(np.float32).min', 'np.finfo(np.float32).max', 'np.float32']
# Length of the action vector output by the policy
ACTION_SPACE_LENGTH: 7
# Path of the source policy checkpoint (.pth) and of the torchscript file to save
TARGET_HAB3_POLICY_PATH: "../weights/mobile_gaze_v2/mg97h103_15_ckpt.35.pth"
OUTPUT_COMBINED_NET_SAVE_PATH: "../weights/mobile_gaze_v2/mg97h103_15_ckpt.35.torchscript"
# Whether to use the stereo pair camera for mobile gaze
USE_STEREO_PAIR_CAMERA: False
# Selects the habitat-lab policy flavour; only the value 'new' is visible here - verify other accepted values
NEW_HABITAT_LAB_POLICY_OR_OLD: 'new'

0 comments on commit d3f0ba6

Please sign in to comment.