#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/errno.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <resolv.h>
#include <netdb.h>
#include <pwd.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
#include "socks.h"

/* Defined at the end of this file; called by Rconnect(), Rbind() and SOCKSinit(). */
int				GetSockPort(void);

/*
**  Address handed back by Rgetsockname(): filled in by Rbind() with the
**   SOCKS server's address and the port taken from the server's reply.
*/
static struct sockaddr_in	cursin;

/*
**  IPv4 address of the SOCKS server, kept as an s_addr value.  Set by
**   SOCKSinit() and used as the connect() target by Rconnect() and Rbind().
*/
static unsigned long		SocksHost;

/*
**  Send the request header in 'dst' to socket 's': version and cmd as
**   8-bit fields, then port and host as 32-bit fields, in that order.
**  Write8/Write32 come from socks.h; they take the fields by value, so
**   they are presumably macros.  NOTE(review): they report no status here,
**   so write errors are not detected by this function.
**  Always returns 0.
*/
static int SendDst(s, dst)
int	s;
Socks_t	*dst;
{
	Write8(s, dst->version);
	Write8(s, dst->cmd);
	Write32(s, dst->port);
	Write32(s, dst->host);

	return 0;
}

/*
**  Read a reply header from socket 's' into 'dst': version and cmd as
**   8-bit fields, then port and host as 32-bit fields, in that order.
**  Read8/Read32 come from socks.h and are passed the lvalues directly, so
**   they are presumably macros.  NOTE(review): a short or failed read is
**   not reported here, so callers cannot tell whether 'dst' was updated.
**  Always returns 0.
*/
static int GetDst(s, dst)
int	s;
Socks_t	*dst;
{
	Read8(s, dst->version);
	Read8(s, dst->cmd);
	Read32(s, dst->port);
	Read32(s, dst->host);

	return 0;
}

/*
**  connect() replacement: reach the IPv4 destination 'sin' through the
**   SOCKS server.
**  Connects 'sock' to SocksHost on the SOCKS port and sends a SOCKS_CONNECT
**   request carrying sin's port and address.  When SOCKS_VERSION > 2, it then
**   sends the NUL-terminated local user name (an empty name if the uid is
**   unknown).  Finally it reads the server's reply.
**  Returns 0 on success.  On failure returns -1 with errno set to:
**    EFAULT        'sin' is NULL
**    EAFNOSUPPORT  'size' is wrong or 'sin' is not AF_INET
**    ETIMEDOUT     the SOCKS server could not be reached, the user name
**                  could not be sent, or the server replied SOCKS_FAIL
**                  (the errno from connect()/write() is not preserved)
*/
int
Rconnect(sock, sin, size)
int			sock;
struct sockaddr_in	*sin;
int			size;
{
	Socks_t			dst;
	struct sockaddr_in	nsin;

	if (sin == NULL) {
		errno = EFAULT;
		return -1;
	}
	if (size != (int) sizeof(struct sockaddr_in) ||
	    sin->sin_family != AF_INET) {
		errno = EAFNOSUPPORT;
		return -1;
	}

	/* Zero everything so sin_zero (and sin_len, where present) are clean. */
	memset(&nsin, 0, sizeof(nsin));
	nsin.sin_family = AF_INET;
	nsin.sin_port = GetSockPort();
	nsin.sin_addr.s_addr = SocksHost;

	if (connect(sock, (struct sockaddr *) &nsin, sizeof(nsin)) < 0) {
		errno = ETIMEDOUT;
		return -1;
	}

	dst.version = SOCKS_VERSION;
	dst.cmd = SOCKS_CONNECT;
	dst.port = sin->sin_port;
	dst.host = sin->sin_addr.s_addr;

	SendDst(sock, &dst);

#if SOCKS_VERSION > 2
	{
		/* Identify the local user; an unknown uid sends just the NUL. */
		struct passwd	*pw = getpwuid(getuid());
		const char	*user = (pw != NULL) ? pw->pw_name : "";
		size_t		len = strlen(user) + 1;	/* include the NUL */

		if (write(sock, user, len) != (ssize_t) len) {
			errno = ETIMEDOUT;
			return -1;
		}
	}
#endif

	GetDst(sock, &dst);

	if (dst.cmd == SOCKS_FAIL) {
		errno = ETIMEDOUT;
		return -1;
	}

	return 0;
}

/*
**  Set up a bind for a remote host, add fill 'cursin' in with the
**   remote server information.
*/
Rbind(sock, sin, size, remhost)
int			sock;
struct sockaddr_in	*sin;
int			size;
unsigned long		remhost;
{
	struct sockaddr_in	nsin;
	Socks_t			dst;
	struct passwd		*pw;

	nsin.sin_family = AF_INET;
	nsin.sin_port = GetSockPort();
	nsin.sin_addr.s_addr = SocksHost;

	if (connect(sock, &nsin, sizeof(struct sockaddr_in)) < 0)
		return -1;

	dst.version = SOCKS_VERSION;
	dst.cmd     = SOCKS_BIND;
	dst.port    = 0;
	dst.host    = remhost;

	SendDst(sock, &dst);

#if SOCKS_VERSION > 2
	if ((pw = getpwuid(getuid())) == NULL) {
		char	c = '\0';
		write(sock, &c, 1);
	} else {
		write(sock, pw->pw_name, strlen(pw->pw_name) + 1);
	}
#endif

	GetDst(sock, &dst);

	cursin.sin_family = AF_INET;
	cursin.sin_port = dst.port;
	cursin.sin_addr.s_addr = SocksHost;

	return 0;
}

/*
**  listen() replacement: a stub, since the listen will already have
**   succeeded on the SOCKS server.  Both arguments are ignored.
**  Always returns 0.
*/
int
Rlisten(s, n)
int	s, n;
{
	(void) s;
	(void) n;
	return 0;
}

/*
**  Well we know where we got a connection from.
*/
Rgetsockname(sock, sin, size)
int			sock;
struct sockaddr_in	*sin;
int			*size;
{
	*size = sizeof(struct sockaddr_in);
	*sin = cursin;

	return 0;
}

/*
**  Do an accept, which is really a select for some data on
**    the present socket.
*/
Raccept(sock, sin, size)
int			sock;
struct sockaddr_in	*sin;
int			*size;
{
	fd_set		fds;
	Socks_t		dst;

	FD_ZERO(&fds);
	FD_SET(sock, &fds);

	if (select(getdtablesize(), &fds, NULL, NULL, NULL) > 0)
		if (FD_ISSET(sock, &fds)) {
			GetDst(sock, &dst);
			sin->sin_family = AF_INET;
			sin->sin_port = dst.port;
			sin->sin_addr.s_addr = dst.host;

			return dup(sock);
		}
	return -1;
}

/*
**  Resolve 'name' (a host name or dotted quad) into '*addr'.
**  Uses gethostbyname() when it yields an AF_INET address of the right
**   size; otherwise falls back to inet_addr(), which yields INADDR_NONE
**   for an unparsable name.
*/
static void SocksResolveHost(name, addr)
char		*name;
struct in_addr	*addr;
{
	struct hostent	*hp;

	hp = gethostbyname(name);
	/* Guard the copy: 'addr' holds exactly one IPv4 address. */
	if (hp != NULL && hp->h_addrtype == AF_INET &&
	    hp->h_length == (int) sizeof(*addr))
		memcpy(addr, hp->h_addr_list[0], sizeof(*addr));
	else
		addr->s_addr = inet_addr(name);
}

/*
**  One-time library setup.
**  Calls res_init().  With NEED_REMOTE_NAMESERVER, it then points the
**   resolver at a single nameserver taken from $SOCKS_NS, else $SOCKS_HOST,
**   else SOCKS_DEFAULT_NS.
**  Sets SocksHost from $SOCKS_HOST, else SOCKS_DEFAULT_HOST.
**  Finally calls GetSockPort() so a missing port is caught early.
**  Always returns 0.  NOTE(review): an unresolvable SOCKS host silently
**   becomes INADDR_NONE (as before); consider reporting it.
*/
int
SOCKSinit()
{
#ifdef NEED_REMOTE_NAMESERVER
	static char	defaultNS[] = SOCKS_DEFAULT_NS;
#endif
	static char	defaultHOST[] = SOCKS_DEFAULT_HOST;
	struct in_addr	addr;
	char		*name;

	res_init();

#ifdef NEED_REMOTE_NAMESERVER
	if ((name = getenv("SOCKS_NS")) == NULL &&
	    (name = getenv("SOCKS_HOST")) == NULL)
		name = defaultNS;

	SocksResolveHost(name, &_res.nsaddr_list[0].sin_addr);
	_res.nscount = 1;
#endif

	if ((name = getenv("SOCKS_HOST")) == NULL)
		name = defaultHOST;

	/*
	**  Resolve into a real in_addr, then assign.  Copying 4 raw bytes into
	**   the (possibly 8-byte) unsigned long put them in the wrong half on
	**   big-endian 64-bit hosts.
	*/
	SocksResolveHost(name, &addr);
	SocksHost = addr.s_addr;

	/*
	**  Make sure it is defined before we let things
	**   get too far along.
	*/
	(void) GetSockPort();

	return 0;
}

/*
**  Return the SOCKS service port, cached after the first successful lookup.
**  The port comes from getservbyname("socks", "tcp"), whose s_port is in
**   network byte order.  If that fails, SOCKS_DEF_PORT is returned when it
**   is defined.  NOTE(review): confirm SOCKS_DEF_PORT is also in network
**   order, since callers store it directly in sin_port.
**  Without SOCKS_DEF_PORT, an unknown service prints a message and exits
**   the process with status 1.
*/
int GetSockPort()
{
	static int	cached = -1;
	struct servent	*entry;

	if (cached <= 0) {
		entry = getservbyname("socks", "tcp");
		if (entry != NULL) {
			cached = entry->s_port;
		} else {
#ifdef SOCKS_DEF_PORT
			cached = SOCKS_DEF_PORT;
#else
			fprintf(stderr,"Unknown service socks/tcp\n");
			exit(1);
#endif
		}
	}

	return cached;
}
