diff configuration/configuration.py @ 108:a2184db43fe2

fix dict command line processing
author Jeff Hammel <jhammel@mozilla.com>
date Thu, 03 May 2012 15:46:13 -0700
parents 83d66a9bdef0
children 09642528be02
line wrap: on
line diff
--- a/configuration/configuration.py	Thu May 03 08:26:10 2012 -0700
+++ b/configuration/configuration.py	Thu May 03 15:46:13 2012 -0700
@@ -111,6 +111,15 @@
     """option that keeps track if it is seen"""
     # TODO: this should be configurable or something
     def take_action(self, action, dest, opt, value, values, parser):
+
+        # switch on types
+        formatter = getattr(parser, 'cli_formatter')
+        if formatter:
+            formatter = formatter(dest)
+            if formatter:
+                value = formatter(value)
+
+        # call the optparse front-end
         optparse.Option.take_action(self, action, dest, opt, value, values, parser)
 
         # add the parsed option to the set of things parsed
@@ -118,13 +127,6 @@
             parser.parsed = set()
         parser.parsed.add(dest)
 
-        # switch on types
-        formatter = getattr(parser, 'cli_formatter')
-        if formatter:
-            formatter = formatter(dest)
-            if formatter:
-                setattr(values, dest, formatter(getattr(values, dest)))
-
 ### plugins for option types
 
 class BaseCLI(object):
@@ -202,11 +204,20 @@
 
     delimeter = '='
 
+    def __call__(self, name, value):
+
+        # optparse can't handle dict types OOTB
+        default = value.get('default')
+        if isinstance(default, dict):
+            value = copy.deepcopy(value)
+            value['default'] = default.items()
+
+        return ListCLI.__call__(self, name, value)
+
     def take_action(self, value):
-        bad = [i for i in value if self.delimeter not in i]
-        if bad:
-            raise AssertionError("Each value must be delimited by '%s': %s" % (self.delimeter, bad))
-        return dict([i.split(self.delimeter, 1) for i in value])
+        if self.delimeter not in value:
+            raise AssertionError("Each value must be delimited by '%s': %s" % (self.delimeter, value))
+        return value.split(self.delimeter, 1)
 
 # TODO: 'dict'-type cli interface
 
@@ -376,7 +387,7 @@
 
     def cli_formatter(self, option):
         if option in self.option_dict:
-            handler = self.types[self.option_dict[option].get('type')]
+            handler = self.types[self.option_type(option)]
             return getattr(handler, 'take_action', lambda x: x)
 
     def option_type(self, name):