socketsession.cc 16.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
// PERFORMANCE OF THIS SOFTWARE.

15
16
#include <config.h>

17
18
#include <unistd.h>

19
20
#include <sys/types.h>
#include <sys/socket.h>
21
#include <sys/uio.h>
22
23
24
25
#include <sys/un.h>

#include <netinet/in.h>

26
#include <fcntl.h>
27
28
#include <stdint.h>

29
30
#include <cerrno>
#include <csignal>
31
#include <cstddef>
32
#include <cstring>
33
#include <cassert>
34

35
36
37
#include <string>
#include <vector>

38
39
#include <boost/noncopyable.hpp>

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <exceptions/exceptions.h>

#include <util/buffer.h>

#include "fd_share.h"
#include "socketsession.h"
#include "sockaddr_util.h"

using namespace std;

namespace isc {
namespace util {
namespace io {

using namespace internal;

// The expected max size of the session header: a 2-byte header length,
// six 32-bit fields, and two sockaddr structures (see the
// SocketSessionUtility overview description in the header file).
// sizeof(struct sockaddr_storage) should be the largest possible size of any
// sockaddr structure.
const size_t DEFAULT_HEADER_BUFLEN = sizeof(uint16_t) + sizeof(uint32_t) * 6 +
    sizeof(struct sockaddr_storage) * 2;

// The allowable maximum size of data passed with the socket FD.  For now
// we use a fixed value of 65535, the largest possible size of valid DNS
// messages.  We may enlarge it or make it configurable as we see the need
// for more flexibility.
const int MAX_DATASIZE = 65535;

// The initial buffer size for receiving socket session data in the receiver.
// This value is the maximum message size of DNS messages carried over UDP
// (without EDNS).  In our expected usage (at the moment) this should be
// sufficiently large: the expected data is AXFR/IXFR queries or UPDATE
// requests.  The former are generally quite small; while the latter could
// be large, they would often be small enough for a single UDP message.
// If it turns out that there are many exceptions, we may want to extend
// the class so that this value can be customized.  Note that the buffer
// will be automatically extended for longer data and this is only about
// efficiency.
const size_t INITIAL_BUFSIZE = 512;

// The (default) socket buffer size for the forwarder and receiver.  This is
// chosen to be sufficiently large to store two full-size DNS messages
// (each with a maximum-size session header).  We may want to customize this
// value in future.
const int SOCKSESSION_BUFSIZE = (DEFAULT_HEADER_BUFLEN + MAX_DATASIZE) * 2;

// Private state of SocketSessionForwarder.
//
// All members are given a defined initial value here (the address is
// zero-filled) so a freshly constructed object has no indeterminate fields;
// the forwarder's constructor then fills in the receiver's address.
struct SocketSessionForwarder::ForwarderImpl {
    ForwarderImpl() : sock_un_(), sock_un_len_(0), fd_(-1),
                      buf_(DEFAULT_HEADER_BUFLEN)
    {}

    struct sockaddr_un sock_un_; // UNIX domain address used by connect()
    socklen_t sock_un_len_;      // length of sock_un_ passed to connect()
    int fd_;                     // connected socket, or -1 if not connected
    OutputBuffer buf_;           // reused buffer for building session headers
};
93

94
95
96
SocketSessionForwarder::SocketSessionForwarder(const std::string& unix_file) :
    impl_(NULL)
{
97
98
    // We need to filter SIGPIPE for subsequent push().  See the class
    // description.
99
100
101
102
    if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
        isc_throw(Unexpected, "Failed to filter SIGPIPE: " << strerror(errno));
    }

103
104
105
106
107
108
109
    ForwarderImpl impl;
    if (sizeof(impl.sock_un_.sun_path) - 1 < unix_file.length()) {
        isc_throw(SocketSessionError,
                  "File name for a UNIX domain socket is too long: " <<
                  unix_file);
    }
    impl.sock_un_.sun_family = AF_UNIX;
110
111
112
    // the copy should be safe due to the above check, but we'd be rather
    // paranoid about making it 100% sure even if the check has a bug (with
    // triggering the assertion in the worse case)
113
114
115
    strncpy(impl.sock_un_.sun_path, unix_file.c_str(),
            sizeof(impl.sock_un_.sun_path));
    assert(impl.sock_un_.sun_path[sizeof(impl.sock_un_.sun_path) - 1] == '\0');
116
117
    impl.sock_un_len_ = offsetof(struct sockaddr_un, sun_path) +
        unix_file.length();
118
#ifdef HAVE_SA_LEN
119
    impl.sock_un_.sun_len = impl.sock_un_len_;
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#endif
    impl.fd_ = -1;

    impl_ = new ForwarderImpl;
    *impl_ = impl;
}

// Tears down the forwarder: an open connection (if any) is closed first,
// then the implementation object is freed.  close() is only invoked while a
// descriptor is held, since it throws otherwise.
SocketSessionForwarder::~SocketSessionForwarder() {
    const bool connected = (impl_->fd_ != -1);
    if (connected) {
        close();
    }
    delete impl_;
}

void
135
SocketSessionForwarder::connectToReceiver() {
136
    if (impl_->fd_ != -1) {
137
        isc_throw(BadValue, "Duplicate connect to UNIX domain "
138
139
140
141
142
143
144
145
                  "endpoint " << impl_->sock_un_.sun_path);
    }

    impl_->fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
    if (impl_->fd_ == -1) {
        isc_throw(SocketSessionError, "Failed to create a UNIX domain socket: "
                  << strerror(errno));
    }
146
147
148
149
150
151
152
153
154
155
156
157
    // Make the socket non blocking
    int fcntl_flags = fcntl(impl_->fd_, F_GETFL, 0);
    if (fcntl_flags != -1) {
        fcntl_flags |= O_NONBLOCK;
        fcntl_flags = fcntl(impl_->fd_, F_SETFL, fcntl_flags);
    }
    if (fcntl_flags == -1) {
        close();   // note: this is the internal method, not ::close()
        isc_throw(SocketSessionError,
                  "Failed to make UNIX domain socket non blocking: " <<
                  strerror(errno));
    }
158
159
160
161
162
163
164
165
166
167
    // Ensure the socket send buffer is large enough.  If we can't get the
    // current size, simply set the sufficient size.
    int sndbuf_size;
    socklen_t sndbuf_size_len = sizeof(sndbuf_size);
    if (getsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &sndbuf_size,
                   &sndbuf_size_len) == -1 ||
        sndbuf_size < SOCKSESSION_BUFSIZE) {
        if (setsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &SOCKSESSION_BUFSIZE,
                       sizeof(SOCKSESSION_BUFSIZE)) == -1) {
            close();
168
169
170
            isc_throw(SocketSessionError,
                      "Failed to set send buffer size to " <<
                          SOCKSESSION_BUFSIZE);
171
        }
172
    }
173
174
    if (connect(impl_->fd_, convertSockAddr(&impl_->sock_un_),
                impl_->sock_un_len_) == -1) {
175
        close();
176
177
178
179
180
181
182
183
184
        isc_throw(SocketSessionError, "Failed to connect to UNIX domain "
                  "endpoint " << impl_->sock_un_.sun_path << ": " <<
                  strerror(errno));
    }
}

void
SocketSessionForwarder::close() {
    if (impl_->fd_ == -1) {
185
        isc_throw(BadValue, "Attempt of close before connect");
186
187
188
189
190
191
    }
    ::close(impl_->fd_);
    impl_->fd_ = -1;
}

void
192
SocketSessionForwarder::push(int sock, int family, int type, int protocol,
193
194
195
196
                             const struct sockaddr& local_end,
                             const struct sockaddr& remote_end,
                             const void* data, size_t data_len)
{
197
    if (impl_->fd_ == -1) {
198
        isc_throw(BadValue, "Attempt of push before connect");
199
200
201
202
    }
    if ((local_end.sa_family != AF_INET && local_end.sa_family != AF_INET6) ||
        (remote_end.sa_family != AF_INET && remote_end.sa_family != AF_INET6))
    {
203
        isc_throw(BadValue, "Invalid address family: must be "
204
205
206
207
208
                  "AF_INET or AF_INET6; " <<
                  static_cast<int>(local_end.sa_family) << ", " <<
                  static_cast<int>(remote_end.sa_family) << " given");
    }
    if (family != local_end.sa_family || family != remote_end.sa_family) {
209
        isc_throw(BadValue, "Inconsistent address family: must be "
210
211
212
213
                  << static_cast<int>(family) << "; "
                  << static_cast<int>(local_end.sa_family) << ", "
                  << static_cast<int>(remote_end.sa_family) << " given");
    }
214
    if (data_len == 0 || data == NULL) {
215
        isc_throw(BadValue, "Data for a socket session must not be empty");
216
    }
217
    if (data_len > MAX_DATASIZE) {
218
        isc_throw(BadValue, "Invalid socket session data size: " <<
219
220
                  data_len << ", must not exceed " << MAX_DATASIZE);
    }
221

222
223
224
    if (send_fd(impl_->fd_, sock) != 0) {
        isc_throw(SocketSessionError, "FD passing failed: " <<
                  strerror(errno));
225
    }
226
227
228
229
230
231

    impl_->buf_.clear();
    // Leave the space for the header length
    impl_->buf_.skip(sizeof(uint16_t));
    // Socket properties: family, type, protocol
    impl_->buf_.writeUint32(static_cast<uint32_t>(family));
232
    impl_->buf_.writeUint32(static_cast<uint32_t>(type));
233
234
235
236
237
238
239
    impl_->buf_.writeUint32(static_cast<uint32_t>(protocol));
    // Local endpoint
    impl_->buf_.writeUint32(static_cast<uint32_t>(getSALength(local_end)));
    impl_->buf_.writeData(&local_end, getSALength(local_end));
    // Remote endpoint
    impl_->buf_.writeUint32(static_cast<uint32_t>(getSALength(remote_end)));
    impl_->buf_.writeData(&remote_end, getSALength(remote_end));
240
241
242
243
    // Data length.  Must be fit uint32 due to the range check above.
    const uint32_t data_len32 = static_cast<uint32_t>(data_len);
    assert(data_len == data_len32); // shouldn't cause overflow.
    impl_->buf_.writeUint32(data_len32);
244
245
246
    // Write the resulting header length at the beginning of the buffer
    impl_->buf_.writeUint16At(impl_->buf_.getLength() - sizeof(uint16_t), 0);

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
    const struct iovec iov[2] = {
        { const_cast<void*>(impl_->buf_.getData()), impl_->buf_.getLength() },
        { const_cast<void*>(data), data_len }
    };
    const int cc = writev(impl_->fd_, iov, 2);
    if (cc != impl_->buf_.getLength() + data_len) {
        if (cc < 0) {
            isc_throw(SocketSessionError,
                      "Write failed in forwarding a socket session: " <<
                      strerror(errno));
        }
        isc_throw(SocketSessionError,
                  "Incomplete write in forwarding a socket session: " << cc <<
                  "/" << (impl_->buf_.getLength() + data_len));
    }
262
263
264
265
266
}

// Bundles the attributes of a received socket session.  The endpoint and
// data pointers are stored as given (no copy is made), so the constructor
// rejects NULL pointers and empty data with BadValue.
SocketSession::SocketSession(int sock, int family, int type, int protocol,
                             const sockaddr* local_end,
                             const sockaddr* remote_end,
                             const void* data, size_t data_len) :
    sock_(sock), family_(family), type_(type), protocol_(protocol),
    local_end_(local_end), remote_end_(remote_end),
    data_(data), data_len_(data_len)
{
    const bool have_endpoints = (local_end != NULL) && (remote_end != NULL);
    if (!have_endpoints) {
        isc_throw(BadValue, "sockaddr must be non NULL for SocketSession");
    }
    if (data_len == 0) {
        isc_throw(BadValue, "data_len must be non 0 for SocketSession");
    }
    if (data == NULL) {
        isc_throw(BadValue, "data must be non NULL for SocketSession");
    }
}

283
284
struct SocketSessionReceiver::ReceiverImpl {
    ReceiverImpl(int fd) : fd_(fd),
285
286
                           sa_local_(convertSockAddr(&ss_local_)),
                           sa_remote_(convertSockAddr(&ss_remote_)),
287
288
                           header_buf_(DEFAULT_HEADER_BUFLEN),
                           data_buf_(INITIAL_BUFSIZE)
289
290
291
292
    {
        if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &SOCKSESSION_BUFSIZE,
                       sizeof(SOCKSESSION_BUFSIZE)) == -1) {
            isc_throw(SocketSessionError,
293
294
                      "Failed to set receive buffer size to " <<
                          SOCKSESSION_BUFSIZE);
295
296
        }
    }
297
298

    const int fd_;
299
    struct sockaddr_storage ss_local_; // placeholder for local endpoint
300
    struct sockaddr* const sa_local_;
301
    struct sockaddr_storage ss_remote_; // placeholder for remote endpoint
302
303
    struct sockaddr* const sa_remote_;

304
    // placeholder for session header and data
305
306
    vector<uint8_t> header_buf_;
    vector<uint8_t> data_buf_;
307
308
};

309
310
SocketSessionReceiver::SocketSessionReceiver(int fd) :
    impl_(new ReceiverImpl(fd))
311
312
313
{
}

314
SocketSessionReceiver::~SocketSessionReceiver() {
315
316
317
    delete impl_;
}

318
319
320
321
322
323
324
325
326
327
328
329
namespace {
// A shortcut to throw common exception on failure of recv(2)
void
readFail(int actual_len, int expected_len) {
    if (expected_len < 0) {
        isc_throw(SocketSessionError, "Failed to receive data from "
                  "SocketSessionForwarder: " << strerror(errno));
    }
    isc_throw(SocketSessionError, "Incomplete data from "
              "SocketSessionForwarder: " << actual_len << "/" <<
              expected_len);
}
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347

// A helper container for a (socket) file descriptor used in
// SocketSessionReceiver::pop that ensures the socket is closed unless it
// can be safely passed to the caller via release().
struct ScopedSocket : boost::noncopyable {
    ScopedSocket(int fd) : fd_(fd) {}
    ~ScopedSocket() {
        if (fd_ >= 0) {
            close(fd_);
        }
    }
    int release() {
        const int fd = fd_;
        fd_ = -1;
        return (fd);
    }
    int fd_;
};
348
349
}

350
SocketSession
351
SocketSessionReceiver::pop() {
352
353
    ScopedSocket passed_sock(recv_fd(impl_->fd_));
    if (passed_sock.fd_ == FD_SYSTEM_ERROR) {
354
355
        isc_throw(SocketSessionError, "Receiving a forwarded FD failed: " <<
                  strerror(errno));
356
    } else if (passed_sock.fd_ < 0) {
357
358
        isc_throw(SocketSessionError, "No FD forwarded");
    }
359
360

    uint16_t header_len;
361
    const int cc_hlen = recv(impl_->fd_, &header_len, sizeof(header_len),
362
                        MSG_WAITALL);
363
364
365
    if (cc_hlen < sizeof(header_len)) {
        readFail(cc_hlen, sizeof(header_len));
    }
366
    header_len = InputBuffer(&header_len, sizeof(header_len)).readUint16();
367
368
369
370
    if (header_len > DEFAULT_HEADER_BUFLEN) {
        isc_throw(SocketSessionError, "Too large header length: " <<
                  header_len);
    }
371
372
    impl_->header_buf_.clear();
    impl_->header_buf_.resize(header_len);
373
374
375
376
377
    const int cc_hdr = recv(impl_->fd_, &impl_->header_buf_[0], header_len,
                            MSG_WAITALL);
    if (cc_hdr < header_len) {
        readFail(cc_hdr, header_len);
    }
378
379

    InputBuffer ibuffer(&impl_->header_buf_[0], header_len);
380
381
382
383
384
385
386
387
388
    try {
        const int family = static_cast<int>(ibuffer.readUint32());
        if (family != AF_INET && family != AF_INET6) {
            isc_throw(SocketSessionError,
                      "Unsupported address family is passed: " << family);
        }
        const int type = static_cast<int>(ibuffer.readUint32());
        const int protocol = static_cast<int>(ibuffer.readUint32());
        const socklen_t local_end_len = ibuffer.readUint32();
389
390
391
392
393
        const socklen_t endpoint_minlen = (family == AF_INET) ?
            sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
        if (local_end_len < endpoint_minlen ||
            local_end_len > sizeof(impl_->ss_local_)) {
            isc_throw(SocketSessionError, "Invalid local SA length: " <<
394
395
396
397
                      local_end_len);
        }
        ibuffer.readData(&impl_->ss_local_, local_end_len);
        const socklen_t remote_end_len = ibuffer.readUint32();
398
399
400
        if (remote_end_len < endpoint_minlen ||
            remote_end_len > sizeof(impl_->ss_remote_)) {
            isc_throw(SocketSessionError, "Invalid remote SA length: " <<
401
402
403
                      remote_end_len);
        }
        ibuffer.readData(&impl_->ss_remote_, remote_end_len);
404
405
        if (family != impl_->sa_local_->sa_family ||
            family != impl_->sa_remote_->sa_family) {
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
            isc_throw(SocketSessionError, "SA family inconsistent: " <<
                      static_cast<int>(impl_->sa_local_->sa_family) << ", " <<
                      static_cast<int>(impl_->sa_remote_->sa_family) <<
                      " given, must be " << family);
        }
        const size_t data_len = ibuffer.readUint32();
        if (data_len == 0 || data_len > MAX_DATASIZE) {
            isc_throw(SocketSessionError,
                      "Invalid socket session data size: " << data_len <<
                      ", must be > 0 and <= " << MAX_DATASIZE);
        }

        impl_->data_buf_.clear();
        impl_->data_buf_.resize(data_len);
        const int cc_data = recv(impl_->fd_, &impl_->data_buf_[0], data_len,
                                 MSG_WAITALL);
        if (cc_data < data_len) {
            readFail(cc_data, data_len);
        }

426
        return (SocketSession(passed_sock.release(), family, type, protocol,
427
428
                              impl_->sa_local_, impl_->sa_remote_,
                              &impl_->data_buf_[0], data_len));
429
430
431
432
433
434
    } catch (const InvalidBufferPosition& ex) {
        // We catch the case where the given header is too short and convert
        // the exception to SocketSessionError.
        isc_throw(SocketSessionError, "bogus socket session header: " <<
                  ex.what());
    }
435
436
437
438
439
}

}
}
}