Statistics
| Revision:

gvsig-scripting / org.gvsig.scripting / trunk / org.gvsig.scripting / org.gvsig.scripting.app / org.gvsig.scripting.app.mainplugin / src / main / resources-plugin / scripting / lib / dulwich / tests / test_protocol.py @ 959

History | View | Annotate | Download (10.8 KB)

1
# test_protocol.py -- Tests for the git protocol
2
# Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
3
#
4
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
5
# General Public License as public by the Free Software Foundation; version 2.0
6
# or (at your option) any later version. You can redistribute it and/or
7
# modify it under the terms of either of these two licenses.
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
#
15
# You should have received a copy of the licenses; if not, see
16
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
17
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
18
# License, Version 2.0.
19
#
20

    
21
"""Tests for the smart protocol utility functions."""
22

    
23

    
24
from io import BytesIO
25

    
26
from dulwich.errors import (
27
    HangupException,
28
    )
29
from dulwich.protocol import (
30
    GitProtocolError,
31
    PktLineParser,
32
    Protocol,
33
    ReceivableProtocol,
34
    extract_capabilities,
35
    extract_want_line_capabilities,
36
    ack_type,
37
    SINGLE_ACK,
38
    MULTI_ACK,
39
    MULTI_ACK_DETAILED,
40
    BufferedPktLineWriter,
41
    )
42
from dulwich.tests import TestCase
43

    
44

    
45
class BaseProtocolTests(object):
46

    
47
    def test_write_pkt_line_none(self):
48
        self.proto.write_pkt_line(None)
49
        self.assertEqual(self.rout.getvalue(), b'0000')
50

    
51
    def test_write_pkt_line(self):
52
        self.proto.write_pkt_line(b'bla')
53
        self.assertEqual(self.rout.getvalue(), b'0007bla')
54

    
55
    def test_read_pkt_line(self):
56
        self.rin.write(b'0008cmd ')
57
        self.rin.seek(0)
58
        self.assertEqual(b'cmd ', self.proto.read_pkt_line())
59

    
60
    def test_eof(self):
61
        self.rin.write(b'0000')
62
        self.rin.seek(0)
63
        self.assertFalse(self.proto.eof())
64
        self.assertEqual(None, self.proto.read_pkt_line())
65
        self.assertTrue(self.proto.eof())
66
        self.assertRaises(HangupException, self.proto.read_pkt_line)
67

    
68
    def test_unread_pkt_line(self):
69
        self.rin.write(b'0007foo0000')
70
        self.rin.seek(0)
71
        self.assertEqual(b'foo', self.proto.read_pkt_line())
72
        self.proto.unread_pkt_line(b'bar')
73
        self.assertEqual(b'bar', self.proto.read_pkt_line())
74
        self.assertEqual(None, self.proto.read_pkt_line())
75
        self.proto.unread_pkt_line(b'baz1')
76
        self.assertRaises(ValueError, self.proto.unread_pkt_line, b'baz2')
77

    
78
    def test_read_pkt_seq(self):
79
        self.rin.write(b'0008cmd 0005l0000')
80
        self.rin.seek(0)
81
        self.assertEqual([b'cmd ', b'l'], list(self.proto.read_pkt_seq()))
82

    
83
    def test_read_pkt_line_none(self):
84
        self.rin.write(b'0000')
85
        self.rin.seek(0)
86
        self.assertEqual(None, self.proto.read_pkt_line())
87

    
88
    def test_read_pkt_line_wrong_size(self):
89
        self.rin.write(b'0100too short')
90
        self.rin.seek(0)
91
        self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
92

    
93
    def test_write_sideband(self):
94
        self.proto.write_sideband(3, b'bloe')
95
        self.assertEqual(self.rout.getvalue(), b'0009\x03bloe')
96

    
97
    def test_send_cmd(self):
98
        self.proto.send_cmd(b'fetch', b'a', b'b')
99
        self.assertEqual(self.rout.getvalue(), b'000efetch a\x00b\x00')
100

    
101
    def test_read_cmd(self):
102
        self.rin.write(b'0012cmd arg1\x00arg2\x00')
103
        self.rin.seek(0)
104
        self.assertEqual((b'cmd', [b'arg1', b'arg2']), self.proto.read_cmd())
105

    
106
    def test_read_cmd_noend0(self):
107
        self.rin.write(b'0011cmd arg1\x00arg2')
108
        self.rin.seek(0)
109
        self.assertRaises(AssertionError, self.proto.read_cmd)
110

    
111

    
112
class ProtocolTests(BaseProtocolTests, TestCase):
113

    
114
    def setUp(self):
115
        TestCase.setUp(self)
116
        self.rout = BytesIO()
117
        self.rin = BytesIO()
118
        self.proto = Protocol(self.rin.read, self.rout.write)
119

    
120

    
121
class ReceivableBytesIO(BytesIO):
122
    """BytesIO with socket-like recv semantics for testing."""
123

    
124
    def __init__(self):
125
        BytesIO.__init__(self)
126
        self.allow_read_past_eof = False
127

    
128
    def recv(self, size):
129
        # fail fast if no bytes are available; in a real socket, this would
130
        # block forever
131
        if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
132
            raise GitProtocolError('Blocking read past end of socket')
133
        if size == 1:
134
            return self.read(1)
135
        # calls shouldn't return quite as much as asked for
136
        return self.read(size - 1)
137

    
138

    
139
class ReceivableProtocolTests(BaseProtocolTests, TestCase):
140

    
141
    def setUp(self):
142
        TestCase.setUp(self)
143
        self.rout = BytesIO()
144
        self.rin = ReceivableBytesIO()
145
        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
146
        self.proto._rbufsize = 8
147

    
148
    def test_eof(self):
149
        # Allow blocking reads past EOF just for this test. The only parts of
150
        # the protocol that might check for EOF do not depend on the recv()
151
        # semantics anyway.
152
        self.rin.allow_read_past_eof = True
153
        BaseProtocolTests.test_eof(self)
154

    
155
    def test_recv(self):
156
        all_data = b'1234567' * 10  # not a multiple of bufsize
157
        self.rin.write(all_data)
158
        self.rin.seek(0)
159
        data = b''
160
        # We ask for 8 bytes each time and actually read 7, so it should take
161
        # exactly 10 iterations.
162
        for _ in range(10):
163
            data += self.proto.recv(10)
164
        # any more reads would block
165
        self.assertRaises(GitProtocolError, self.proto.recv, 10)
166
        self.assertEqual(all_data, data)
167

    
168
    def test_recv_read(self):
169
        all_data = b'1234567'  # recv exactly in one call
170
        self.rin.write(all_data)
171
        self.rin.seek(0)
172
        self.assertEqual(b'1234', self.proto.recv(4))
173
        self.assertEqual(b'567', self.proto.read(3))
174
        self.assertRaises(GitProtocolError, self.proto.recv, 10)
175

    
176
    def test_read_recv(self):
177
        all_data = b'12345678abcdefg'
178
        self.rin.write(all_data)
179
        self.rin.seek(0)
180
        self.assertEqual(b'1234', self.proto.read(4))
181
        self.assertEqual(b'5678abc', self.proto.recv(8))
182
        self.assertEqual(b'defg', self.proto.read(4))
183
        self.assertRaises(GitProtocolError, self.proto.recv, 10)
184

    
185
    def test_mixed(self):
186
        # arbitrary non-repeating string
187
        all_data = b','.join(str(i).encode('ascii') for i in range(100))
188
        self.rin.write(all_data)
189
        self.rin.seek(0)
190
        data = b''
191

    
192
        for i in range(1, 100):
193
            data += self.proto.recv(i)
194
            # if we get to the end, do a non-blocking read instead of blocking
195
            if len(data) + i > len(all_data):
196
                data += self.proto.recv(i)
197
                # ReceivableBytesIO leaves off the last byte unless we ask
198
                # nicely
199
                data += self.proto.recv(1)
200
                break
201
            else:
202
                data += self.proto.read(i)
203
        else:
204
            # didn't break, something must have gone wrong
205
            self.fail()
206

    
207
        self.assertEqual(all_data, data)
208

    
209

    
210
class CapabilitiesTestCase(TestCase):
211

    
212
    def test_plain(self):
213
        self.assertEqual((b'bla', []), extract_capabilities(b'bla'))
214

    
215
    def test_caps(self):
216
        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la'))
217
        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la\n'))
218
        self.assertEqual((b'bla', [b'la', b'la']), extract_capabilities(b'bla\0la la'))
219

    
220
    def test_plain_want_line(self):
221
        self.assertEqual((b'want bla', []), extract_want_line_capabilities(b'want bla'))
222

    
223
    def test_caps_want_line(self):
224
        self.assertEqual((b'want bla', [b'la']),
225
                extract_want_line_capabilities(b'want bla la'))
226
        self.assertEqual((b'want bla', [b'la']),
227
                extract_want_line_capabilities(b'want bla la\n'))
228
        self.assertEqual((b'want bla', [b'la', b'la']),
229
                extract_want_line_capabilities(b'want bla la la'))
230

    
231
    def test_ack_type(self):
232
        self.assertEqual(SINGLE_ACK, ack_type([b'foo', b'bar']))
233
        self.assertEqual(MULTI_ACK, ack_type([b'foo', b'bar', b'multi_ack']))
234
        self.assertEqual(MULTI_ACK_DETAILED,
235
                          ack_type([b'foo', b'bar', b'multi_ack_detailed']))
236
        # choose detailed when both present
237
        self.assertEqual(MULTI_ACK_DETAILED,
238
                          ack_type([b'foo', b'bar', b'multi_ack',
239
                                    b'multi_ack_detailed']))
240

    
241

    
242
class BufferedPktLineWriterTests(TestCase):
243

    
244
    def setUp(self):
245
        TestCase.setUp(self)
246
        self._output = BytesIO()
247
        self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
248

    
249
    def assertOutputEquals(self, expected):
250
        self.assertEqual(expected, self._output.getvalue())
251

    
252
    def _truncate(self):
253
        self._output.seek(0)
254
        self._output.truncate()
255

    
256
    def test_write(self):
257
        self._writer.write(b'foo')
258
        self.assertOutputEquals(b'')
259
        self._writer.flush()
260
        self.assertOutputEquals(b'0007foo')
261

    
262
    def test_write_none(self):
263
        self._writer.write(None)
264
        self.assertOutputEquals(b'')
265
        self._writer.flush()
266
        self.assertOutputEquals(b'0000')
267

    
268
    def test_flush_empty(self):
269
        self._writer.flush()
270
        self.assertOutputEquals(b'')
271

    
272
    def test_write_multiple(self):
273
        self._writer.write(b'foo')
274
        self._writer.write(b'bar')
275
        self.assertOutputEquals(b'')
276
        self._writer.flush()
277
        self.assertOutputEquals(b'0007foo0007bar')
278

    
279
    def test_write_across_boundary(self):
280
        self._writer.write(b'foo')
281
        self._writer.write(b'barbaz')
282
        self.assertOutputEquals(b'0007foo000abarba')
283
        self._truncate()
284
        self._writer.flush()
285
        self.assertOutputEquals(b'z')
286

    
287
    def test_write_to_boundary(self):
288
        self._writer.write(b'foo')
289
        self._writer.write(b'barba')
290
        self.assertOutputEquals(b'0007foo0009barba')
291
        self._truncate()
292
        self._writer.write(b'z')
293
        self._writer.flush()
294
        self.assertOutputEquals(b'0005z')
295

    
296

    
297
class PktLineParserTests(TestCase):
298

    
299
    def test_none(self):
300
        pktlines = []
301
        parser = PktLineParser(pktlines.append)
302
        parser.parse(b"0000")
303
        self.assertEqual(pktlines, [None])
304
        self.assertEqual(b"", parser.get_tail())
305

    
306
    def test_small_fragments(self):
307
        pktlines = []
308
        parser = PktLineParser(pktlines.append)
309
        parser.parse(b"00")
310
        parser.parse(b"05")
311
        parser.parse(b"z0000")
312
        self.assertEqual(pktlines, [b"z", None])
313
        self.assertEqual(b"", parser.get_tail())
314

    
315
    def test_multiple_packets(self):
316
        pktlines = []
317
        parser = PktLineParser(pktlines.append)
318
        parser.parse(b"0005z0006aba")
319
        self.assertEqual(pktlines, [b"z", b"ab"])
320
        self.assertEqual(b"a", parser.get_tail())