/* socket.cpp
 * Based on James Pee's article 'Guidelines for Wrapping Sockets in Classes'
 * published in the C/C++ Users journal.
 * Copyright (C) 2003 Michel Leunen
 * http://www.leunen.com/
 * michel@leunen.com
 *
 * This software is freeware and is provided at no charge to the user.
 * This software is provided 'as-is', without any express or implied warranty.
 * To the maximum extent permitted by applicable law will the author NOT be held
 * liable for any damages whatsoever arising from the use of this software.
 *
 * Permission is granted to anyone to use this software for any purpose, including
 * commercial applications, and to freely redistribute it, with the following
 * restrictions:
 *
 * 1. If you use this software in a product, an acknowledgment in the product
 *    documentation would be appreciated.
 *
 * 2. If you use this software in a commercial application, please send me an email
 *	  detailing the application.
 *
 * 3. This notice may not be removed or altered from any source distribution.
 *
 * 4. It may be distributed by any means, provided that the original files
 * 		as supplied by the author remain intact and no charge is made other than
 * 		for reasonable distribution costs. If you are a vendor, if you want to
 * 		include this software in a compilation or in a magazine cover disk
 * 		you HAVE TO CONTACT ME to obtain my authorization before distribution.
 */

#include <windows.h>
#include <winsock2.h>
#pragma hdrstop
#include <sstream>

#include "socket.h"

/**
 * Exception implementation
 */

namespace
{
	/**
	 * Builds the text reported by socket_exception::what(): the caller's
	 * message followed by the current Winsock error code. The code is read
	 * when the exception is constructed, i.e. right after the failing call,
	 * rather than when what() is invoked, by which time other Winsock calls
	 * (e.g. closesocket() in destructors during unwinding) may have replaced it.
	 */
	std::string FormatSocketError(const std::string &what)
	{
		std::ostringstream ss;
		ss << what
			 << " Error: "
			 << WSAGetLastError();
		return ss.str();
	}
}

/**
 * Stores the message together with the Winsock error code current at
 * construction time.
 * @param what description of the operation that failed.
 */
mlLib::socket_exception::socket_exception(const std::string &what)
		:msg_(FormatSocketError(what))
{

}

/**
 * Returns the formatted message. The pointer refers to the msg_ member and
 * therefore stays valid for the lifetime of the exception object; returning
 * c_str() of a local std::string here would dangle as soon as what() returns.
 */
const char * mlLib::socket_exception::what() const throw()
{
	return msg_.c_str();
}


/**
 * Base Winsock class implementation
 */

/**
 * Number of live Winsock objects. Winsock is started when the count leaves
 * zero and cleaned up when it returns to zero. The counter is a plain int
 * with no locking.
 */
int mlLib::Winsock::refcount_ = 0;

/**
 * Starts Winsock (requesting version 2.0) for the first user and registers
 * one more user.
 * @throws socket_exception if WSAStartup() fails; the count is then unchanged.
 * NOTE(review): if Winsock objects can be copied through an implicit copy
 * constructor, the copy is not counted here but is counted down in the
 * destructor — confirm the header forbids copying.
 */
mlLib::Winsock::Winsock()
{
	const bool firstUser = (refcount_ == 0);
	if(firstUser)
	{
		WSADATA startupData;
		const int status = WSAStartup(MAKEWORD(2,0),&startupData);
		if(status != 0)
			throw socket_exception("Unable to load winsock.");
	}
	refcount_ += 1;
}

/**
 * Unregisters one user and calls WSACleanup() when none remain.
 */
mlLib::Winsock::~Winsock()
{
	if(--refcount_ == 0)
		WSACleanup();
}

/**
 * Utility classes implementation
 */

/**
 * Returns the standard host name of the local machine.
 * @throws socket_exception if gethostname() fails.
 */
std::string mlLib::LocalHost::GetHostName()const
{
	// Winsock documents a 256-byte buffer as always sufficient for
	// gethostname(); a smaller buffer makes long host names fail (WSAEFAULT).
	char Buffer[256];
	if(gethostname(Buffer,sizeof(Buffer)) != 0)
		throw socket_exception("Unable to retrieve local host name");
	return std::string(Buffer);
}

std::string mlLib::LocalHost::GetIPAddress()const
{
	hostent *he = gethostbyname(GetHostName().c_str());
  if(he == 0)throw socket_exception("Invalid hostname");
  in_addr Address;
  memcpy(&Address,he->h_addr_list[0],he->h_length);
  char Buffer[64];
  strcpy(Buffer,inet_ntoa(Address));
  return std::string(Buffer);
}

/**
 * Wraps a host given either as a name or as a dotted IPv4 address.
 * @param address host name or dotted address; stored as-is, not resolved here.
 */
mlLib::Host::Host(const std::string& address)
	: Address_(address)
{
}

/**
 * Tells whether the stored address is a numeric IPv4 address accepted by
 * inet_addr() (e.g. "192.168.0.1") rather than a host name.
 */
bool mlLib::Host::IsDotted()const
{
	unsigned long RemoteAddress = inet_addr(Address_.c_str());
	if(RemoteAddress != INADDR_NONE)return true;
	// inet_addr() also returns INADDR_NONE for the valid broadcast address
	// 255.255.255.255, so that one string has to be recognised explicitly.
	return Address_ == "255.255.255.255";
}

std::string mlLib::Host::GetIPAddress()const
{
	if(!IsDotted())
  {
    hostent *he = gethostbyname(Address_.c_str());
    if(he == 0)throw socket_exception("Invalid hostname");
    in_addr Address;
    memcpy(&Address,he->h_addr_list[0],he->h_length);
    char Buffer[64];
    strcpy(Buffer,inet_ntoa(Address));
    return std::string(Buffer);
  }
  else throw socket_exception("Not a valid hostname");
}

std::string mlLib::Host::GetHostName()const
{
	if(IsDotted())
  {
  	unsigned long RemoteAddress = inet_addr(Address_.c_str());
    hostent *he = gethostbyaddr((char*)&RemoteAddress,4,PF_INET);
    if(he==0)throw socket_exception("Invalid IP address");
    return he->h_name;
  }
  else throw socket_exception("Not a dotted address");
}

/**
 * Socket class implementation
 */

/**
 * Creates an empty Socket that owns no descriptor. Initialized is false, so
 * the operations that check it (Connect, Bind, Listen, Accept,
 * BytesAvailable) throw until a usable Socket is assigned to this one.
 */
mlLib::Socket::Socket()
    : Initialized(false),
      Domain(AF_INET),
      // Same "no descriptor" sentinel used by ReleaseOwnership() and Close().
      SocketDescriptor(INVALID_SOCKET),
      ClientIP("")
{

}

/**
 * Creates a new socket with socket(domain, type, protocol).
 * Failure is not reported here: Initialized stays false and the operations
 * that check it throw later.
 */
mlLib::Socket::Socket(int domain,int type,int protocol)
    : Initialized(false),
      Domain(domain),
      SocketDescriptor(socket(domain,type,protocol))
{
  Initialized = (SocketDescriptor != INVALID_SOCKET);
}

/**
 * Adopts an already-open descriptor, e.g. one returned by accept(). The
 * descriptor is closed by ~Socket() while this object still owns it.
 * @param fd       descriptor to adopt; INVALID_SOCKET yields an uninitialized Socket.
 * @param clientip dotted address of the peer, stored in ClientIP.
 */
mlLib::Socket::Socket(int fd,const std::string& clientip)
    : Initialized(fd != INVALID_SOCKET),
      // Domain must not be left indeterminate: GetAddress() and Bind() copy
      // it into sin_family. The descriptor's real family is not known here;
      // AF_INET matches the default constructor.
      Domain(AF_INET),
      SocketDescriptor(fd),
      ClientIP(clientip)
{

}

/**
 * Copy constructor and assignment operator transfer ownership
 * of the socket to the new Socket object.
 */

/**
 * Ownership-transferring "copy": takes over rhs's descriptor and leaves rhs
 * owning nothing (see ReleaseOwnership()).
 */
mlLib::Socket::Socket(Socket& rhs)
    : Initialized(rhs.Initialized),
      Domain(rhs.Domain),
      SocketDescriptor(INVALID_SOCKET),
      ClientIP(rhs.ClientIP)
{
  // Done in the body so rhs.Initialized has certainly been read above,
  // independent of the order in which the members are declared.
  SocketDescriptor = rhs.ReleaseOwnership();
}

/**
 * Takes over rhs's descriptor, leaving rhs owning nothing (as after
 * ReleaseOwnership()). A different descriptor owned by this object is closed
 * first. Self-assignment is a no-op.
 */
mlLib::Socket& mlLib::Socket::operator=(mlLib::Socket& rhs)
{
  // Releasing our own descriptor on self-assignment would mark the object
  // uninitialized and leak the handle.
  if(this == &rhs)return *this;

  // Capture rhs's state before ReleaseOwnership() clears it.
  const bool rhsInitialized = rhs.Initialized;
  int tmp = rhs.ReleaseOwnership();

  // Close() resets Initialized to false, so it must run BEFORE the new state
  // is copied in; otherwise the adopted descriptor is marked uninitialized
  // and is never closed. Only an owned descriptor is closed, as in ~Socket().
  if(Initialized && SocketDescriptor != tmp)Close();

  Initialized = rhsInitialized;
  Domain = rhs.Domain;
  ClientIP = rhs.ClientIP;
  SocketDescriptor = tmp;
  return *this;
}

/**
 * Closes the descriptor if this object still owns one.
 */
mlLib::Socket::~Socket()
{
  if(!Initialized)
    return;
  Close();
}

/**
 * Gives up ownership of the descriptor without closing it.
 * @return the descriptor previously held; this object is left with
 *         INVALID_SOCKET and Initialized == false.
 */
int mlLib::Socket::ReleaseOwnership()
{
  const int released = SocketDescriptor;
  Initialized = false;
  SocketDescriptor = INVALID_SOCKET;
  return released;
}

bool mlLib::Socket::GetAddress(const std::string& host,unsigned short port,
    sockaddr_in* in)
{
  memset(in,0,sizeof(sockaddr_in));
  in->sin_family = Domain;
  in->sin_port = htons(port);
  if(isdigit(host[0]))
  {
    in->sin_addr.s_addr = inet_addr(host.c_str());
  }
  else
  {
    hostent* hostStruct;
    in_addr* hostNode;
    hostStruct = gethostbyname(host.c_str());
    if(hostStruct)
    {
      hostNode = (in_addr*)hostStruct->h_addr;
      in->sin_addr.s_addr = hostNode->s_addr;
    }
    else return false;
  }
  return true;
}

void mlLib::Socket::Connect(const std::string& hostname,u_short port)
{
  if(!Initialized)
  {
    throw socket_exception("Socket not initialized");
  }
  sockaddr_in servaddr;
  if(!GetAddress(hostname,port,&servaddr))
      throw socket_exception("Unable to resolve host name");
  if(connect(SocketDescriptor,(sockaddr*)&servaddr,
  		sizeof(servaddr)) < 0)
  {
    throw socket_exception("Unable to connect to host");
  }
}

/**
 * Binds the socket to port on all local interfaces (INADDR_ANY).
 * @param port local port in host byte order.
 * @throws socket_exception if the socket is not initialized or bind() fails.
 */
void mlLib::Socket::Bind(u_short port)
{
  if(!Initialized)
    throw socket_exception("Socket not initialized");

  sockaddr_in local;
  memset(&local,0,sizeof(local));
  local.sin_family = Domain;
  local.sin_port = htons(port);
  local.sin_addr.s_addr = htonl(INADDR_ANY);

  const int status = bind(SocketDescriptor,
                          reinterpret_cast<sockaddr*>(&local),
                          sizeof(local));
  if(status == -1)
    throw socket_exception("Unable to bind to specified port");
}

void mlLib::Socket::Listen(int connections)
{
  if(!Initialized)
  {
    throw socket_exception("Socket not initialized");
  }
  listen(SocketDescriptor,connections);
}

/**
 * Accepts a pending connection and returns a Socket that owns it. The
 * returned Socket's ClientIP holds the peer's dotted IPv4 address.
 * @throws socket_exception if the socket is not initialized or accept() fails.
 */
mlLib::Socket mlLib::Socket::Accept()
{
  if(!Initialized)
    throw socket_exception("Socket not initialized");

  sockaddr_in peer;
  int peerLength = sizeof(peer);
  const int fd = accept(SocketDescriptor,
                        reinterpret_cast<sockaddr*>(&peer),
                        &peerLength);
  if(fd == INVALID_SOCKET)
    throw socket_exception("Invalid socket error");

  const std::string peerIP = inet_ntoa(peer.sin_addr);
  Socket client(fd,peerIP);
  return client;
}

/**
 * Shuts down both directions, closes the descriptor and leaves the object
 * owning nothing. Return values of shutdown()/closesocket() are ignored.
 */
void mlLib::Socket::Close()
{
  // Stop traffic in both directions, then release the handle.
  shutdown(SocketDescriptor,SD_BOTH);
  closesocket(SocketDescriptor);

  // Mark the object as no longer owning a descriptor.
  Initialized = false;
  SocketDescriptor = INVALID_SOCKET;
}

unsigned long mlLib::Socket::BytesAvailable()
{
 	if(!Initialized)
  {
    throw socket_exception("Socket not initialized");
  }
  unsigned long BytesAvailable;
  ioctlsocket(SocketDescriptor,FIONREAD,&BytesAvailable);
  return BytesAvailable;
}

bool mlLib::Socket::Readable()
{
  fd_set rset;
  FD_ZERO(&rset);
  FD_SET(SocketDescriptor,&rset);
  if(select(0,&rset,NULL,NULL,NULL) == SOCKET_ERROR)
  {
    throw socket_exception("Error reading through socket");
  }
  return FD_ISSET(SocketDescriptor,&rset);
}

/**
 * Read maximum size bytes. The function returns the number of bytes
 * read. If the returned value is 0, the connection was closed by peer.
 * But be aware that all characters must be read before the connection is closed.
 * That means that the data buffer must be large enough to read all the
 * characters. The following read will yield 0, indicating a connection closed.
 */
int mlLib::Socket::Read(void* data,int size)
{
  int bytesreceived;
  fd_set rset;          
  FD_ZERO(&rset);
  FD_SET(SocketDescriptor,&rset);
  if(select(0,&rset,NULL,NULL,NULL) == SOCKET_ERROR)
  {
    throw socket_exception("Error reading through socket");
  }
  bytesreceived = recv(SocketDescriptor,(char*)data,size,0);
  if(bytesreceived == SOCKET_ERROR)
      throw socket_exception("Error reading through socket");
  return bytesreceived;
}

void mlLib::Socket::Write(const void* data,int size)
{
  char* ptr = (char*)data;
  int nleft = size;
  while(nleft > 0)
  {
    int nwritten = send(SocketDescriptor,ptr,nleft,0);
    nleft -= nwritten;
    ptr += nwritten;
  }
}

/**
 * Reads one byte at a time until a '\n' has been read and returns the line
 * including that '\n'. If the peer closes the connection first, returns ""
 * and any partial line read so far is discarded.
 */
std::string mlLib::Socket::ReadLine()
{
  std::string line;
  for(;;)
  {
    char ch;
    if(Read(&ch,1) == 0)
      return "";
    line += ch;
    if(ch == '\n')
      return line;
  }
}

/**
 * Sends the characters of data (without a terminating null) via Write().
 */
void mlLib::Socket::WriteString(const std::string& data)
{
  const char* text = data.c_str();
  const int length = static_cast<int>(data.length());
  Write(text,length);
}

/**
 * Waits for and receives one datagram on desc into data (at most size bytes),
 * storing the sender's address in *source.
 * @param desc   descriptor to wait on and read from.
 * @param source receives the sender address; must point to at least
 *               sizeof(sockaddr) bytes.
 * @return number of bytes received.
 * @throws socket_exception if select() or recvfrom() fails.
 */
int mlLib::Socket::RecvFrom(int desc,void* data,int size,sockaddr* source)
{
  fd_set rset;
  FD_ZERO(&rset);
  // Wait on the descriptor that is actually read below; waiting on
  // SocketDescriptor instead is only correct when desc happens to equal it.
  FD_SET(desc,&rset);
  if(select(0,&rset,NULL,NULL,NULL) <= 0)
  {
    throw socket_exception("Error reading through socket");
  }
  int sourcelen = sizeof(sockaddr);
  int bytesreceived = recvfrom(desc,(char*)data,size,0,source,&sourcelen);
  if(bytesreceived == SOCKET_ERROR)
      throw socket_exception("Error reading through socket");
  return bytesreceived;
}

void mlLib::Socket::SendTo(int desc,const void* data,int size,sockaddr* target)
{
  char* ptr = (char*)data;
  int nleft = size;
  while(nleft > 0)
  {
    int nwritten = sendto(desc,ptr,nleft,0,target,sizeof(sockaddr));
    nleft -= nwritten;
    ptr += nwritten;
  }
}

/**
 * TCP Client
 */

/**
 * Creates a TCP socket and connects it to host:port.
 * @throws socket_exception if the host cannot be resolved or the connection fails.
 */
mlLib::TCPClient::TCPClient(const std::string& host,unsigned short port)
    : Socket(AF_INET,SOCK_STREAM,IPPROTO_TCP)
{
  this->Connect(host,port);
}

/**
 * TCP Server
 */

/**
 * Creates a TCP socket bound to port on all interfaces and starts listening.
 * @param connections maximum length of the pending-connection queue.
 */
mlLib::TCPServer::TCPServer(unsigned short port,int connections)
    : Socket(AF_INET,SOCK_STREAM,IPPROTO_TCP)
{
  this->Bind(port);
  this->Listen(connections);
}

/**
 * Accepts the next pending client connection (see Socket::Accept()).
 * Only the TCP server has this step: UDP is connectionless, so UDPServer
 * reads datagrams directly instead.
 */
mlLib::Socket mlLib::TCPServer::AcceptClient()
{
  return this->Accept();
}

/**
 * UDP Client
 */

/**
 * Creates a UDP socket and resolves host:port into RemoteAddress, the
 * destination used by WriteTo().
 * @throws socket_exception if host cannot be resolved.
 */
mlLib::UDPClient::UDPClient(const std::string& host,unsigned short port)
    : Socket(AF_INET,SOCK_DGRAM,IPPROTO_UDP)
{
  this->ResolveAddress(host,port);
}

void mlLib::UDPClient::ResolveAddress(const std::string& host,unsigned short port)
{
  if(!GetAddress(host,port,(sockaddr_in*)&RemoteAddress))
    throw socket_exception("Unable to resolve host name");
}

int mlLib::UDPClient::ReadFrom(void* data,int size)
{
  return RecvFrom(SocketDescriptor,data,size,&RemoteAddress);
}

void mlLib::UDPClient::WriteTo(const void* data,int size)
{
  SendTo(SocketDescriptor,data,size,&RemoteAddress);
}

/**
 * UDP Server
 */

/**
 * Creates a UDP socket bound to port on all local interfaces.
 */
mlLib::UDPServer::UDPServer(unsigned short port)
    : Socket(AF_INET,SOCK_DGRAM,IPPROTO_UDP)
{
  this->Bind(port);
}

int mlLib::UDPServer::ReadFrom(void* data,int size)
{
  return RecvFrom(SocketDescriptor,data,size,&ClientAddress);
}

void mlLib::UDPServer::WriteTo(const void* data,int size)
{
  SendTo(SocketDescriptor,data,size,&ClientAddress);
}
