/* Copyright (c) 1995,1996,1997 NEC Corporation.  All rights reserved.       */
/*                                                                           */
/* The redistribution, use and modification in source or binary forms of     */
/* this software is subject to the conditions set forth in the copyright     */
/* document ("Copyright") included with this distribution.                   */

/*
 * $Id: socket.c,v 1.94.2.2.2.9 1998/07/19 22:51:26 wlu Exp $
 */

/* This file contains all of the socket creation and configuration functions */
/* that are used by the daemon.  There are a couple of utility functions     */
/* here because they deal with the sockets themselves...                     */
#include "socks5p.h"
#include "daemon.h"
#include "validate.h"
#include "msgids.h"
#include "socket.h"
#include "proxy.h"
#include "sema.h"
#include "log.h"
#include "threads.h"
#include "sigfix.h"

/* On Linux builds with thread support, SIGUSR1 is remapped to SIGUNUSED.    */
/* NOTE(review): presumably because the thread library of the day reserved   */
/* SIGUSR1; SIGUNUSED is not defined by recent glibc - confirm when porting. */
#if defined(linux) && defined(USE_THREADS)
#undef SIGUSR1
#define SIGUSR1 SIGUNUSED
#endif

/* Two server modes (threading & preforking) work in a master/slave mode;    */
/* these macros make decisions based on that status easier to read.          */
#define ISMASTER() (!iamachild && (servermode == PREFORKING || servermode == THREADED))
#define ISSLAVE()  (iamachild  && (servermode == PREFORKING || servermode == THREADED))

/* acceptor and nchildren are also updated by the SIGCHLD handler            */
/* (gravedigger), so they are volatile: the main loop must re-read them      */
/* after SigPause() instead of using a cached copy.                          */
static volatile int acceptor  = 0;  /* pid of the accepting process if threaded */
static int nconns             = 0;  /* the number of proxies the daemon has had */
static volatile int nchildren = 0;  /* live children (decremented when reaped)  */
static int iamachild          = 0;  /* Am I a child, or the parent...           */
static void *asem = NULL;           /* accept semaphore shared by the slaves    */
static S5IOHandle in = S5InvalidIOHandle;  /* the listening socket              */

/* Flags set by the signal handlers below and polled by the main loops.      */
/* Objects written from a signal handler must be volatile sig_atomic_t       */
/* (C11 7.14.1.1); without volatile the compiler may keep them in registers. */
static volatile sig_atomic_t hadfatalsig = 0;  /* Has a SIGHUP occurred?           */
static volatile sig_atomic_t hadsigint   = 0;  /* Has a SIGINT/SIGTERM occurred?   */
static volatile sig_atomic_t hadresetsig = 0;  /* Has a SIGUSR1 occurred?          */

/* SIGCHLD handler.  Reaps, without blocking, every child that has exited so */
/* none lingers as a zombie, and keeps the child bookkeeping current: each   */
/* reaped child decrements nchildren (never below zero) and, in THREADED     */
/* mode, clears acceptor so that GetSignals forks a replacement acceptor.    */
/* errno is saved on entry and restored on exit.                             */
/*                                                                           */
/* On a non-posix system, calling wait3 or waitpid in a signal handler       */
/* (according to APUE) might not be safe.  If that is ever the case, the     */
/* reaping has to move out of the handler.                                   */
static RETSIGTYPE gravedigger(void) {
    int saved_errno = errno;
    int status, reaped;
    Sigset_t prevmask = SigBlock(SIGCHLD);

    for (;;) {
#ifdef HAVE_WAITPID
        reaped = waitpid(-1, &status, WNOHANG);
#else
        reaped = wait3(&status, WNOHANG, NULL);
#endif

        /* Interrupted before anything was collected: just retry. */
        if (reaped == -1 && errno == EINTR) continue;

        /* 0: nothing left to reap; -1: no children or another error. */
        if (reaped == 0 || reaped == -1) break;

        if (servermode == THREADED) acceptor = 0;
        if (nchildren > 0) nchildren--;
    }

    errno = saved_errno;
    SigUnblock(prevmask);
}

/* SIGINT/SIGTERM handler.  While hadresetsig is set, the signal only clears */
/* that flag; GetSignals relies on this by setting hadresetsig just before   */
/* it sends SIGINT to its process group.  Otherwise it records that we must  */
/* exit & clean up later.                                                    */
static RETSIGTYPE markdone(void) {
    if (!hadresetsig) {
        hadsigint = 1;
        return;
    }

    hadresetsig = 0;
}

/* SIGHUP handler: remember the signal so the config file is re-read (and,   */
/* when preforking/threaded, the accepting children restarted) later.        */
static RETSIGTYPE die(void) {
    hadfatalsig = 1;
}

/* SIGUSR1 handler: remember the signal so a new server is started later.    */
static RETSIGTYPE reset(void) {
    hadresetsig = 1;
}

/* fork(), and do the right thing...the right thing is...handle errors and   */
/* increment nchildren...pretty simple...                                    */
/*                                                                           */
/* Returns:                                                                  */
/*   0   in the child, which resets SIGHUP/SIGUSR1/SIGINT/SIGTERM to their   */
/*       defaults and sets iamachild;                                        */
/*   >0  the child's pid in the parent, after counting it in nchildren;      */
/*   -1  with errno == EAGAIN when nchildren has reached nservers, or with   */
/*       fork()'s errno when the fork itself failed.                         */
/*                                                                           */
/* Callers (GetSignals, DoWork) test errno == EAGAIN after a -1 return, so   */
/* errno is set/restored after logging, which may overwrite it.              */
static int DoFork(void) {
    int pid;

    if (nchildren >= nservers) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(10), 0, "Total number of children is %d", nchildren);
        errno = EAGAIN;
        return -1;
    }

    switch (pid = fork()) {
        case 0:
            Signal(SIGHUP,  SIG_DFL);
            Signal(SIGUSR1, SIG_DFL);
            Signal(SIGINT,  SIG_DFL);
            Signal(SIGTERM, SIG_DFL);

            S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(15), 0, "Child: Starting");
            iamachild = 1;
            return 0;
        case -1: {
            int ferrno = errno;

            /* %m reports fork()'s errno; restore it afterwards for callers. */
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR,     0, "Fork failed: %m");
            errno = ferrno;
            return -1;
        }
        default:
            nchildren++;
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(15), 0, "Parent: %d child%s", nchildren, (nchildren != 1)?"ren":"");
            return pid;
    }
}

/* Work out the address and port the daemon should listen on.                */
/*                                                                           */
/* The specification is taken from bindif if set, otherwise from             */
/* $SOCKS5_BINDINTFC, and may be "addr:port", ":port", "addr", or a bare     */
/* port (starts with a digit and contains no '.').  With no address the      */
/* socket binds to INADDR_ANY.  With no port the "socks" tcp service is      */
/* looked up.  If the port resolves to INVALIDPORT, SOCKS_DEFAULT_PORT is    */
/* used instead.                                                             */
/*                                                                           */
/* Returns 0 with *bndaddr filled in, or -1 (logged) if the address cannot   */
/* be resolved.                                                              */
static int GetBindIntfc(S5NetAddr *bndaddr) {
    u_short bindport = 0;
    char *tmp = NULL, *tmpaddr = NULL, *tmpport = NULL;

    if (bindif) tmp = strdup(bindif);
    else {
        MUTEX_LOCK(env_mutex);
        tmp = getenv("SOCKS5_BINDINTFC");
        if (tmp) tmp = strdup(tmp);
        MUTEX_UNLOCK(env_mutex);
    }

    if (tmp) {
        if ((tmpport = strchr(tmp, ':'))) {
            *tmpport++ = '\0';
            if (*tmp) tmpaddr = tmp;
        } else if (isdigit((unsigned char)*tmp) && !strchr(tmp, '.')) {
            /* isdigit() requires an unsigned char value; a negative plain   */
            /* char from the environment would be undefined behaviour.       */
            tmpport = tmp;
        } else {
            tmpaddr = tmp;
        }
    }

    if (!tmpaddr) {
        memset((char *)bndaddr, 0, sizeof(S5NetAddr));
        bndaddr->sin.sin_family       = AF_INET;
        bndaddr->sin.sin_addr.s_addr  = htonl(INADDR_ANY);
    } else if (lsName2Addr(tmpaddr, bndaddr) < 0 || bndaddr->sin.sin_addr.s_addr == INVALIDADDR ) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, " Invalid address %s specified", tmpaddr);
        free(tmp);  /* previously leaked on this path */
        return -1;
    }

    if (tmpport) lsName2Port(tmpport, "tcp", &bindport);
    else lsName2Port("socks", "tcp", &bindport);

    if (bindport == INVALIDPORT) bindport = htons(SOCKS_DEFAULT_PORT);

    lsAddrSetPort(bndaddr, bindport);

    free(tmp);
    return 0;
}

/* Load the UDP relay port range from $SOCKS5_UDPPORTRANGE, formatted as     */
/* "low-high" or just "low", into ludpport / hudpport.  If the low end does  */
/* not start with a digit the value is logged as invalid and nothing is      */
/* changed.  A missing or non-numeric high end leaves hudpport unchanged.    */
static void GetUdpPortRange(void) {
    char *tmp = NULL, *tmpport = NULL;

    MUTEX_LOCK(env_mutex);
    tmp = getenv("SOCKS5_UDPPORTRANGE");
    if (tmp) tmp = strdup(tmp);
    MUTEX_UNLOCK(env_mutex);

    if (tmp) {
        if ((tmpport = strchr(tmp, '-'))) *tmpport++ = '\0';

        /* isdigit() requires an unsigned char value; a negative plain char  */
        /* taken from the environment would be undefined behaviour.          */
        if (*tmp && isdigit((unsigned char)*tmp)) ludpport = (u_short)atoi(tmp);
        else {
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, " Invalid udp port range %s", tmp);
            free(tmp);
            return;
        }

        if (tmpport && *tmpport && isdigit((unsigned char)*tmpport)) hudpport = (u_short)atoi(tmpport);
        free(tmp);
    }
}

/* Create the daemon's listening TCP socket and store it in *infd.           */
/*                                                                           */
/* start - non-zero at initial startup: the UDP port range is read and,      */
/*         unless DONT_STORE_PID is defined, our pid is written to           */
/*         "<$SOCKS5_PIDFILE or SRVPID_FILE>-<port>".  Zero means a restart, */
/*         which is logged with the current time.                            */
/* infd  - in/out.  If it already holds a valid handle, nothing is created.  */
/*         On failure it is closed (if open) and set to S5InvalidIOHandle.   */
/*                                                                           */
/* $SOCKS5_TIMEOUT, when set and non-empty, is reloaded into idletimeout on  */
/* every call.  Returns 0 on success (or if *infd was already valid), -1 on  */
/* failure; failures are logged.                                             */
static int MakeSocket(int start, S5IOHandle *infd) {
    S5NetAddr bndaddr;
    char *tmp = NULL;
    int on = 1;

    if (!start) {
        time_t now = time(NULL);
        char tbuf[1024];

        MUTEX_LOCK(lt_mutex);
        strftime(tbuf, sizeof(tbuf), "%c", localtime(&now));
        MUTEX_UNLOCK(lt_mutex);
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_INFO, MSGID_SERVER_RESTART, "Socks5 restarting at %s", tbuf);
    }

    MUTEX_LOCK(env_mutex);
    if ((tmp = getenv("SOCKS5_TIMEOUT")) && *tmp) idletimeout = atoi(tmp);
    MUTEX_UNLOCK(env_mutex);

    if (*infd != S5InvalidIOHandle) {
        return 0;
    }

    /* bndaddr is used by the log, socket and bind calls below whenever a    */
    /* new socket is made, so resolve it on restarts too; it used to be      */
    /* filled in only when start was set and was read uninitialized then.    */
    if (GetBindIntfc(&bndaddr) < 0) goto cleanup;
    if (start) GetUdpPortRange();

    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(12), 0, "Socks5 attempting to run on interface %s:%d", ADDRANDPORT(&bndaddr));

    if ((*infd = socket(AF_INET, SOCK_STREAM, 0)) == S5InvalidIOHandle) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, "Socket failed for %s:%d: %m", ADDRANDPORT(&bndaddr));
        goto cleanup;
    }

    if (setsockopt(*infd, SOL_SOCKET, SO_REUSEADDR, (char *)&on, sizeof(int)) < 0) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0,  "Turning on address reuse failed for %s:%d: %m", ADDRANDPORT(&bndaddr));
        goto cleanup;
    }

    if (bind(*infd, (ss *)&bndaddr, sizeof(ssi)) < 0) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, MSGID_SERVER_SOCKS_BIND, "Bind failed for %s:%d: %m", ADDRANDPORT(&bndaddr));
        goto cleanup;
    }

    if (listen(*infd, 5) < 0) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, "Listen failed for %s:%d: %m", ADDRANDPORT(&bndaddr));
        goto cleanup;
    }

    /* If we bound, we're the owner, so put our pid in the pidfile.          */
#ifndef DONT_STORE_PID
    if (start) {
        char abuf[64], *base, *myfl = NULL;
        S5IOHandle fd = S5InvalidIOHandle;
        pid_t pid = getpid();
        struct stat sbuf;

        MUTEX_LOCK(env_mutex);
        base = getenv("SOCKS5_PIDFILE");
        base = base?strdup(base):strdup(SRVPID_FILE);
        MUTEX_UNLOCK(env_mutex);

        /* strdup may fail; only build the name from a valid base.  The name */
        /* is "<base>-<port>": '-', at most 5 digits of a 16-bit port and    */
        /* the NUL fit in the 7 extra bytes.                                 */
        if (base) {
            if ((myfl = malloc(strlen(base)+7))) {
                sprintf(myfl, "%s-%d", base, (int)ntohs(bndaddr.sin.sin_port));
            }
            free(base);
        }

        if (myfl) {
            int flags = O_WRONLY | O_CREAT | O_TRUNC;
            /* Open exclusively if the file doesn't exist or if it does, it  */
            /* is a link, and someone else owns it                           */
            if (lstat(myfl, &sbuf) || (S_ISLNK(sbuf.st_mode) && geteuid() != sbuf.st_uid)) flags |= O_EXCL;
            fd = open(myfl, flags, 0644);
        }

        if (fd == S5InvalidIOHandle) {
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_WARNING, 0, "Error: Failed to open pid file: %s: %m", myfl?myfl:"(null)");
        } else {
            sprintf(abuf, "%d\n", (int)pid);
            WRITEFILE(fd, abuf, strlen(abuf));
            close(fd);
        }

        free(myfl);
    }
#endif

    return 0;

  cleanup:
    if (*infd != S5InvalidIOHandle) CLOSESOCKET(*infd);
    *infd = S5InvalidIOHandle;
    return -1;
}

/* This is called whenever an error occurs that requires a signal to fix.    */
/* If the server is preforking or threaded, the parent begins in this state, */
/* and only children come out of it... If it is normal, this state is        */
/* reached by having a fatal call to accept...                               */
/*                                                                           */
/* Children come here sometimes when they've had fatal signals and need to   */
/* be killed off: any child caller simply exit(0)s.                          */
/*                                                                           */
/* asem - the accept semaphore (this parameter shadows the file-level asem); */
/*        the master destroys it before exiting and resets it on SIGHUP.     */
/* infd - the listening socket; children are only forked while it is valid. */
/*                                                                           */
/* Returns 0 in a freshly forked child (PREFORKING/THREADED), or in NORMAL   */
/* mode once a SIGHUP or SIGUSR1 has been handled.  When hadsigint is set    */
/* (SIGINT/SIGTERM, or a fatal THREADED fork failure) the process logs,      */
/* signals its process group, and exit(0)s instead.                          */
/*                                                                           */
/* NOTE(review): if *infd is invalid and no flag is set, the loop never      */
/* calls SigPause() and keeps hitting `continue` - confirm this state is     */
/* unreachable (GetNetConnection exits if the initial MakeSocket fails).     */
static int GetSignals(void *asem, S5IOHandle *infd) {
    Sigset_t set;

    if (iamachild) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(10), 0, "Childs exiting");
        exit(0);
    }

    /* Block appropriate signals here so that we don't get interrupted while */
    /* we're forking or setting things up...                                 */
    for (set = SigBlock(SIGHUP); ; ) {
	/* Do our thing if everything is ok...                               */
	/* THREADED keeps one acceptor child; PREFORKING forks until DoFork  */
	/* stops returning a pid (limit reached, error, or we are the child). */
	if (*infd != S5InvalidIOHandle) {
	    switch (servermode) {
		case THREADED:
		    if (acceptor == 0) acceptor = DoFork();
		    if (iamachild) goto done;
		    break;
		case PREFORKING:
		    while (DoFork() > 0);
		    if (iamachild) goto done;
		    break;
	    }
	}

        /* A failed acceptor fork other than "limit reached" (EAGAIN) is     */
        /* treated like SIGINT: the server shuts down below.                 */
        if (servermode == THREADED && (acceptor < 0 && errno != EAGAIN)) {
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, "server exiting: fork failed");
            hadsigint = 1;
        }

	/* Wait for any signal to arrive, esp SIGCHLD and SIGHUP.  SIGHUP    */
	/* will cause a re-read of the config file, everything else: loop.   */
	if (!hadfatalsig && !hadsigint && !hadresetsig && *infd != S5InvalidIOHandle) {
	    SigPause();
	}
	
	if (hadsigint) {
	    time_t now = time(NULL);
	    char tbuf[1024];

	    MUTEX_LOCK(lt_mutex);
	    strftime(tbuf, sizeof(tbuf), "%c", localtime(&now));
	    MUTEX_UNLOCK(lt_mutex);
	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_INFO, MSGID_SERVER_STOP, "Socks5 Exiting at: %s", tbuf);

	    hadresetsig = 0;
	    /* Signal the process group whose id is our pid (presumably the  */
	    /* daemon and its children, assuming we lead the group).         */
            kill(-getpid(), SIGINT);
	    if (ISMASTER()) { semdestroy(asem); }
	    exit(0);
	}

	if (!hadfatalsig && !hadresetsig) {
            /* SIGCHLD received...                                           */
	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(15), 0, "Parent reaped? (%d child%s)", nchildren, (nchildren != 1)?"ren":"");
            continue;
	}
	
	if (hadfatalsig) {
	    if (servermode == PREFORKING || servermode == THREADED) {
		/* Kill whoever is accepting (all of them if preforking) */
		/* and reset the semaphore...                            */
		/* Setting hadresetsig first makes our own markdone()    */
		/* swallow the SIGINT we send to the group, rather than  */
		/* setting hadsigint.                                    */
		hadresetsig = 1;
                kill(-getpid(), SIGINT);
		if (ISMASTER()) { semreset(asem, 1); } 
		acceptor = 0;
	    }

	    ReadConfig();
	}
 
	/* SIGUSR1 (e.g. an acceptor that is full): start a new acceptor.    */
	if (hadresetsig && servermode == THREADED) acceptor = 0;

	hadfatalsig = 0;
	hadresetsig = 0;

        if (servermode == NORMAL) break;
    }

  done:
    SigUnblock(set);
    return 0;
}

/* Handle one accepted connection, sd.                                       */
/*                                                                           */
/* NORMAL mode forks first.  The parent closes its copy of sd and returns.   */
/* If the fork was refused with EAGAIN (child limit), the connection is      */
/* dropped and the parent returns.  On any other fork failure the process    */
/* exits with -1.  The child handles the connection and exits with           */
/* HandlePxyConnection's result.                                             */
/*                                                                           */
/* PREFORKING and THREADED handle the connection in-process and return.      */
/* Any other mode (e.g. SINGLESHOT, INETD) handles it and then exits with    */
/* the handler's result.                                                     */
static void DoWork(S5IOHandle sd) {
    int eval;

    if (servermode == NORMAL && (eval = DoFork()) != 0) {
        /* Capture DoFork's errno before CLOSESOCKET can overwrite it, so    */
        /* the EAGAIN test reflects the fork result and not the close.       */
        int ferrno = errno;

        CLOSESOCKET(sd);
        if (eval > 0 || ferrno == EAGAIN) return;
        exit(-1);
    }

    eval = HandlePxyConnection(sd);

    if (servermode == PREFORKING || servermode == THREADED) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(10), 0, "Accept: Done with connection...");
        return;
    }

    exit(eval);
}

/* THREADED-mode worker loop.  First serves the connection it was handed     */
/* (sd), then repeatedly accepts and serves further connections on the       */
/* shared listening socket `in`.  Each accept is serialized by accept_mutex  */
/* and the accept semaphore, and is retried while interrupted (EINTR).  A    */
/* semaphore or accept failure is logged and ends the thread through         */
/* THREAD_EXIT(-1); the loop has no other exit.                              */
static void DoThreadWork(S5IOHandle sd) {
    S5IOHandle conn;
    S5NetAddr peer;
    int peerlen;

    for (conn = sd; ; ) {
        DoWork(conn);

        MUTEX_LOCK(accept_mutex);

        if (semacquire(asem)) {
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(2), 0, "DoThreadWork: Semaphore failure.");
            MUTEX_UNLOCK(accept_mutex);
            THREAD_EXIT(-1);
        }

        peerlen = sizeof(peer);
        memset(&peer, 0, peerlen);

        do {
            conn = accept(in, &peer.sa, &peerlen);
        } while (conn == S5InvalidIOHandle && errno == EINTR);

        semrelease(asem);
        MUTEX_UNLOCK(accept_mutex);

        if (conn == S5InvalidIOHandle) {
            S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(2), 0, "DoThreadWork: accept failure.");
            THREAD_EXIT(-1);
        }
    }
}

/* Get a connection from our input socket.  Then based on the socket's       */
/* data, handle the connection correctly, and go on waiting for more         */
/* connections...                                                            */
/*                                                                           */
/* Setup: install the SIGUSR1/SIGHUP/SIGCHLD handlers, read the config and   */
/* create the listening socket `in` (exit(-1) on failure), then install the  */
/* SIGINT/SIGTERM handler.  In PREFORKING/THREADED mode the master creates   */
/* the accept semaphore and enters GetSignals(), from which only forked      */
/* slaves return; a master that reaches the accept loop is an error.         */
/*                                                                           */
/* Accept loop: wait for a connection (slaves serialize accept() with the    */
/* semaphore, plus accept_mutex when THREADED), then dispatch it: THREADED   */
/* via a new thread running DoThreadWork(), every other mode via DoWork().   */
/* The loop has no break; the function is left only through exit() or       */
/* THREAD_EXIT().                                                            */
static void GetNetConnection(void) {
    /* Block SIGUSR1 for a PREFORKING/THREADED master, SIGHUP otherwise      */
    /* (presumably iamachild is still 0 at this point).                      */
    Sigset_t set = SigBlock(ISMASTER()?SIGUSR1:SIGHUP);
    S5IOHandle afd;
    S5NetAddr source;
    int aerrno;

#if !defined(USE_THREADS) || !defined(HAVE_PTHREAD_H)
    if (servermode == THREADED) {
	S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(0), 0, "Warning: Attempt to run server in threaded mode when threads were not a compile time option");
	S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(0), 0, "Warning: Running as a normal standalone server");
        SigUnblock(set);
        set = SigBlock(SIGHUP);
	servermode = NORMAL;
    }
#endif

    Signal(SIGUSR1, reset);
    Signal(SIGHUP,  die);
    Signal(SIGCHLD, gravedigger);

    ReadConfig();

    if (MakeSocket(1, &in) < 0) {
        S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, "Accept: Failed to make listening socket");
        exit(-1);
    }

    Signal(SIGINT,  markdone);
    Signal(SIGTERM, markdone);

    /* The master stays inside GetSignals(); only children return from it. */
    if (ISMASTER()) {
	asem = semcreate(1);
	GetSignals(asem, &in);
    }

    if (ISMASTER()) {
	S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(0), 0, "Error: Master reached slave code in GetNetConnection");
	exit(EXIT_ERR);
    }

    /* Slaves start with clean flags (they may have been inherited).        */
    if (ISSLAVE()) {
	hadsigint = 0;
	hadresetsig = 0;
    }

    for (;;) {
	int len = sizeof(S5NetAddr);

	/* If an important signal has arrived or the acc fd is corrupted,    */
	/* go into the signal waiting state (and possibly exit - if child).  */
	if (hadfatalsig || hadsigint || in == S5InvalidIOHandle) {
	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(2), 0, "Accept: Processing exception");
	    GetSignals(asem, &in);
	    hadfatalsig = 0;
	}
	
	if (SigPending(ISSLAVE()?SIGUSR1:SIGHUP)) {
	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(2), 0, "Accept: Waiting for a pending fatal signal...");
	    while (!hadfatalsig) SigPause();
	    continue;
	}
	
	/* Try to accept a connection.  We may receive a signal (HUP or      */
	/* USR1) here, if that happens, accept should return -1, with errno  */
	/* set to EINTR.  We'll handle that later, after we release the      */
	/* semaphore.  If we've already received a signal, don't bother      */
	/* accepting the connection, just set afd to invalid and aerrno to   */
	/* EINTR, to simulate having received the signal here.               */
	if (hadfatalsig) {
	    afd    = S5InvalidIOHandle;
	    aerrno = EINTR;
	} else {
	    /* For some reason (thanks to Rich Stevens for pointing this     */
	    /* out), System 5 is unhappy about a bunch of people calling     */
	    /* accept.  So we'll add some locks around it to synchronize     */
	    /* access...                                                     */
	    if (servermode == THREADED) MUTEX_LOCK(accept_mutex);

	    if (ISSLAVE()) {
		S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(15), 0, "Accept: Acquiring semaphore");
		SigUnblock(set);

		if (semacquire(asem)) {
		    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(2), 0, "Accept: Semaphore failure.");
	    	    if (servermode == THREADED) {
			MUTEX_UNLOCK(accept_mutex);
			/* Ask the parent (SIGUSR1) for a replacement acceptor. */
			if (nconns < nthreads) kill(getppid(), SIGUSR1);
			THREAD_EXIT(-1);
		    } else { 
			CLOSESOCKET(in);
			exit(-1);
		    }
		}

		set = SigBlock((servermode == NORMAL)?SIGHUP:SIGUSR1);
	    }

	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(10), 0, "Accept: Waiting on accept or a signal");

	    /* Signals are unblocked only while sitting in accept(), so a    */
	    /* signal shows up as accept failing with EINTR.                 */
	    SigUnblock(set);
	    afd = accept(in, &source.sa, &len);
	    if (afd == S5InvalidIOHandle) aerrno = errno;
	    set = SigBlock((servermode == NORMAL)?SIGHUP:SIGUSR1);
	}

	/* Since we've got the connection, release the semaphore.            */
	/*                                                                   */
	/* We don't have to worry about releasing a reset semaphore, if we   */
	/* received a signal in the meantime, since semreset *should* create */
	/* a whole new semaphore, so we'll be releasing one which no one     */
	/* else looks at anymore.                                            */
	if (ISSLAVE()) {
	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_DEBUG(15), 0, "Accept: Releasing semaphore");
	    semrelease(asem);
	    if (servermode == THREADED) MUTEX_UNLOCK(accept_mutex);
	}

	/* Hand the accepted connection (afd) to the worker for this server  */
	/* mode.  When we're done, clean things up so we can do it all again. */
	if (afd == S5InvalidIOHandle) {
	    /* We have to make sure the error wasn't too serious.  If it     */
	    /* was, quit, unless we are the parent, in which case we wait    */
	    /* for a HUP to tell us things are fixed.                        */
	    if (aerrno == EINTR) continue;

	    errno = aerrno;
	    S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, MSGID_SERVER_SOCKS_ACCEPT, "Accept: Accept failed: %m");

            if (ISSLAVE()) {
		if (servermode == THREADED) {
		    if (nconns < nthreads) kill(getppid(), SIGUSR1);
		    THREAD_EXIT(-1);
		} else {
		    CLOSESOCKET(in);
		    exit(-1);
		}
	    }

            /* It is NORMAL mode and system resource is exhausted. sleep     */
	    /* while and continue...                                         */
	    sleep(180);
	} else if (servermode == THREADED) {
#if defined(USE_THREADS) && defined(HAVE_PTHREAD_H)
	    THREAD_T tid;
	    ATTR_T attr;
	    /* NOTE(review): this `set` shadows the outer Sigset_t set, and  */
	    /* it is left empty by sigemptyset() below, so the SIG_BLOCK /   */
	    /* SIG_UNBLOCK calls appear to change no signals - confirm.      */
	    sigset_t set, oset;

	    /* At the thread limit this process stops spawning threads: ask  */
	    /* the parent (SIGUSR1) for a new acceptor and serve connections */
	    /* here via DoThreadWork(), which only leaves via THREAD_EXIT.   */
	    if ((nconns + 1) >= nthreads) {
		kill(getppid(), SIGUSR1);
		nconns++;
		DoThreadWork(afd);
	    }

	    sigemptyset(&set);

            THREAD_ATTR_INIT(attr);
            THREAD_ATTR_SETSTACKSIZE(&attr, 51200);
            THREAD_ATTR_SETSCOPE(attr, PTHREAD_SCOPE_SYSTEM);
            THREAD_ATTR_SETDETACHSTATE(attr, PTHREAD_CREATE_DETACHED);
            THREAD_SIGMASK(SIG_BLOCK, set, oset);

	    
            /* On creation failure, serve in this thread instead (as above). */
            if (THREAD_CREATE(&tid, attr, (void *(*)P((void *)))DoThreadWork, (void *)afd) < 0) {
		S5LogUpdate(S5LogDefaultHandle, S5_LOG_ERROR, 0, "Accept: Thread creation failed: %m");

		kill(getppid(), SIGUSR1);
		nconns++;
		DoThreadWork(afd);
	    } else nconns++;

            /* NOTE(review): the thread was created with the DETACHED        */
            /* attribute, so detaching it again is presumably redundant.     */
            THREAD_DETACH(tid);
	    THREAD_SIGMASK(SIG_UNBLOCK, set, oset);
#endif
	    afd = S5InvalidIOHandle;
	} else {
	    /* SINGLESHOT serves one connection: drop the listener first.    */
	    if (servermode == SINGLESHOT) CLOSESOCKET(in);
    	    nconns++;
	    DoWork(afd);
	}
    }
}

/* Setup a connection which has already been set up for us by inetd.         */
/* Basically, we're just filling in the right structures and calling the     */
/* work function, HandleProxyConnection...                                   */
void GetStdioConnection(void) {
    ReadConfig();
    fclose(stdout);
    fclose(stderr);
    DoWork(STDIN_FILENO);
}

/* Daemon entry point.  SIGPIPE is ignored so that writing to a peer that    */
/* has gone away fails with EPIPE instead of killing the process.  Then, in  */
/* INETD mode, the single connection inetd supplied on stdin is served;      */
/* every other server mode listens on and accepts from its own socket.       */
/*                                                                           */
/* Defined with an explicit (void) parameter list: the old empty () form is  */
/* a non-prototype definition that disables argument checking before C23.    */
void GetConnection(void) {
    Signal(SIGPIPE, SIG_IGN);

    if (servermode == INETD) GetStdioConnection();
    else                     GetNetConnection();
}
