aboutsummaryrefslogtreecommitdiff
path: root/lib/templating.py
blob: 1fe34658bfa9ecd55cf186d3b613bfd356128ff1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#!/usr/bin/python
#
# Copyright 2011 Nick Mathewson, Michael Stone
#
#  You may do anything with this work that copyright law would normally
#  restrict, so long as you retain the above notice(s) and this license
#  in all redistributed copies and derived works.  There is no warranty.

"""
>>> base = Environ(foo=99, bar=600)
>>> derived1 = Environ(parent=base, bar=700, quux=32)
>>> base["foo"]
99
>>> sorted(base.keys())
['bar', 'foo']
>>> derived1["foo"]
99
>>> base["bar"]
600
>>> derived1["bar"]
700
>>> derived1["quux"]
32
>>> sorted(derived1.keys())
['bar', 'foo', 'quux']
>>> class Specialized(Environ):
...    def __init__(self, p=None, **kw):
...       Environ.__init__(self, p, **kw)
...       self._n_calls = 0
...    def _get_expensive_value(self, me):
...       self._n_calls += 1
...       return "Let's pretend this is hard to compute"
...
>>> s = Specialized(base, quux="hi")
>>> s["quux"]
'hi'
>>> s['expensive_value']
"Let's pretend this is hard to compute"
>>> s['expensive_value']
"Let's pretend this is hard to compute"
>>> s._n_calls
1
>>> sorted(s.keys())
['bar', 'expensive_value', 'foo', 'quux']

>>> bt = _BetterTemplate("Testing ${hello}, $goodbye$$, $foo , ${a:b:c}")
>>> bt.safe_substitute({'a:b:c': "4"}, hello=1, goodbye=2, foo=3)
'Testing 1, 2$, 3 , 4'

>>> t = Template("${include:/dev/null} $hi_there")
>>> sorted(t.freevars())
['hi_there']
>>> t.format(dict(hi_there=99))
' 99'
>>> t2 = Template("X$${include:$fname} $bar $baz")
>>> t2.format(dict(fname="/dev/null", bar=33, baz="$foo", foo=1337))
'X 33 1337'
>>> sorted(t2.freevars({'fname':"/dev/null"}))
['bar', 'baz', 'fname']

"""

from __future__ import with_statement

import string
import os
import re

#class _KeyError(KeyError):
#    pass

_KeyError = KeyError

class _DictWrapper:
    def __init__(self, parent=None):
        self._parent = parent

    def __getitem__(self, key):
        try:
            return self._getitem(key)
        except KeyError:
            pass
        if self._parent is None:
            raise _KeyError(key)

        try:
            return self._parent[key]
        except KeyError:
            raise _KeyError(key)

class Environ(_DictWrapper):
    """Layered mapping of named values, with optional lazily computed keys.

    Explicit values come from the keyword arguments (or later
    ``env[key] = value``).  A subclass may define ``_get_<name>`` methods.
    The first lookup of ``<name>`` calls the bound method as ``fn(self)``,
    so it receives the instance twice (``def _get_x(self, me)``), and the
    result is cached.  Keys not found here are looked up in ``parent``.
    """

    def __init__(self, parent=None, **kw):
        _DictWrapper.__init__(self, parent)
        self._dict = kw
        self._cache = {}

    def _getitem(self, key):
        # Explicit values take precedence over previously computed ones.
        for store in (self._dict, self._cache):
            try:
                return store[key]
            except KeyError:
                pass

        compute = getattr(self, "_get_%s" % key, None)
        if compute is None:
            raise KeyError(key)
        try:
            value = compute(self)
        except _KeyError:
            raise KeyError(key)
        self._cache[key] = value
        return value

    def __setitem__(self, key, val):
        self._dict[key] = val

    def keys(self):
        """Return a set of every key this environment can answer.

        This includes the parent's keys and the names of all ``_get_``
        methods, whether or not they have been computed yet.
        """
        names = set(self._dict)
        names.update(self._cache)
        if self._parent is not None:
            names.update(self._parent.keys())
        prefix = "_get_"
        names.update(attr[len(prefix):] for attr in dir(self)
                     if attr.startswith(prefix))
        return names

class IncluderDict(_DictWrapper):
    """Answer ``include:<path>`` keys with the contents of the named file.

    Other keys, and include keys whose file cannot be found, fall through to
    ``parent``.  Relative paths are tried against each directory in
    ``includePath`` in order, and the first existing non-directory entry
    wins.  Absolute paths are used as-is.
    """

    def __init__(self, parent, includePath=(".",)):
        _DictWrapper.__init__(self, parent)
        self._includePath = includePath

    def _getitem(self, key):
        prefix = "include:"
        if not key.startswith(prefix):
            raise KeyError(key)

        filename = key[len(prefix):]
        if os.path.isabs(filename):
            candidates = [filename]
        else:
            candidates = [os.path.join(d, filename) for d in self._includePath]

        for fullname in candidates:
            # Skip directories: exists() alone accepts them, and open() would
            # then fail instead of trying the next include directory.
            # (isfile() would be too strict: it rejects e.g. /dev/null.)
            if os.path.exists(fullname) and not os.path.isdir(fullname):
                with open(fullname, 'r') as f:
                    return f.read()

        # Missing includes (absolute or relative) are reported uniformly as
        # KeyError, so the parent mapping gets a chance to answer them.
        raise KeyError(key)

class _BetterTemplate(string.Template):

    idpattern = r'[a-z0-9:_/\.\-]+'

    def __init__(self, template):
        string.Template.__init__(self, template)

class _FindVarsHelper:
    def __init__(self, dflts):
        self._dflts = dflts
        self._vars = set()
    def __getitem__(self, var):
        self._vars.add(var)
        try:
            return self._dflts[var]
        except KeyError:
            return ""

class Template:
    """Text template expanded with ``$name`` / ``${name}`` placeholders.

    Expansion repeats until the text stops changing, so substituted values
    may themselves contain placeholders.  ``${include:<path>}`` placeholders
    are replaced by the contents of the named file, searched for in
    ``includePath``.
    """

    # format() raises ValueError if the text is still changing after
    # MAX_ITERATIONS + 1 substitution passes.
    MAX_ITERATIONS = 32

    def __init__(self, pattern, includePath=(".",)):
        self._pat = pattern
        self._includePath = includePath

    def freevars(self, defaults=None):
        """Return the set of variable names looked up while expanding.

        Names present in ``defaults`` expand to their value there, so any
        placeholders inside those values are explored too.  All other names
        expand to "".  ``include:`` keys whose file is found are not
        reported, because the include lookup answers them first.
        """
        recorder = _FindVarsHelper({} if defaults is None else defaults)
        self.format(recorder)
        return recorder._vars

    def format(self, values):
        """Expand the template against ``values`` and return the result.

        ``values`` is any object supporting ``values[name]``.  Raises
        KeyError for a name with no value.  Raises ValueError for a
        malformed placeholder, or if expansion does not settle.

        Note: each pass turns ``$$`` into ``$``, and the next pass may then
        read that ``$`` as the start of a placeholder.  This is what lets
        ``$${include:$fname}`` expand in two steps.
        """
        lookup = IncluderDict(values, self._includePath)
        text = self._pat
        for _ in range(self.MAX_ITERATIONS + 1):
            expanded = _BetterTemplate(text).substitute(lookup)
            if expanded == text:
                return expanded
            text = expanded
        raise ValueError("Too many iterations in expanding template!")

if __name__ == '__main__':
    # With no arguments, run this module's doctests.  Otherwise print the
    # free variables of each template file named on the command line.
    # Output uses forms valid on both Python 2 and Python 3 (the old print
    # statements were a SyntaxError on Python 3).
    import sys
    if len(sys.argv) == 1:
        import doctest
        doctest.testmod()
        print("done")
    else:
        for fn in sys.argv[1:]:
            with open(fn, 'r') as f:
                t = Template(f.read())
            # Sorted so the output order is deterministic.
            sys.stdout.write("%s %s\n" % (fn, sorted(t.freevars())))