本节我们通过代码来实现建立TCP连接时的三次握手过程。首先回顾一下TCP数据包头部的结构:
代码实现TCP三次握手:程序实现-LMLPHP
我们将按照上图所示的结构来构建数据包,相关代码如下:

/**
 * TCP layer of the hand-written protocol stack.
 *
 * <p>{@link #createHeader} serialises an outgoing TCP segment and fills in its checksum. The
 * segment consists of:
 * <ul>
 *   <li>the 20-byte fixed header;</li>
 *   <li>an MSS option;</li>
 *   <li>a window-scale option;</li>
 *   <li>an end-of-option-list byte;</li>
 *   <li>an optional payload.</li>
 * </ul>
 * {@link #handlePacket} parses the fixed 20-byte header of a received segment into a map.
 *
 * <p>The last sequence/acknowledgement numbers are kept in unsynchronised instance fields and
 * are reused when a caller does not supply new ones.
 */
public class TCPProtocolLayer implements IProtocol {
	/** Size of the fixed part of a TCP header, without options. */
	private static final int HEADER_LENGTH = 20;
	/** Size of the IPv4 pseudo header: src IP, dst IP, zero byte, protocol, TCP length. */
	private static final int PSEUDO_HEADER_LENGTH = 12;
	public static final byte TCP_PROTOCOL_NUMBER = 6;
	/** Byte offset of the checksum field inside the TCP header. */
	private static final int POSITION_FOR_CHECKSUM = 16;
	private static final byte MAXIMUN_SEGMENT_OPTION_KIND = 2;
	private static final byte MAXIMUN_SEGMENT_SIZE_OPTION_LENGTH = 4;
	/** MSS value advertised in every outgoing segment. */
	private static final short MAXIMUM_SEGMENT_SIZE = 1460;
	private static final byte WINDOW_SCALE_OPTION_KIND = 3;
	private static final byte WINDOW_SCALE_OPTION_LENGTH = 3;
	/** Window-scale shift count advertised in every outgoing segment. */
	private static final byte WINDOW_SCALE_SHIFT_BYTES = 6;
	/** Option kind 0: end of option list. */
	private static final byte END_OF_OPTION_LIST = 0;
	/** Receive window advertised in every outgoing segment. */
	private static final char WINDOW_SIZE = 65535;
	private static final byte TCP_URG_BIT = (1 << 5);
	private static final byte TCP_ACK_BIT = (1 << 4);
	private static final byte TCP_PSH_BIT = (1 << 3);
	private static final byte TCP_RST_BIT = (1 << 2);
	private static final byte TCP_SYN_BIT = (1 << 1);
	private static final byte TCP_FIN_BIT = (1);

	/** Sequence number written when the caller does not pass "seq_num"; updated when it does. */
	private int sequence_number = 2;
	/** Acknowledgement number written when the caller does not pass "ack_num"; updated when it does. */
	private int acknowledgement_number = 0;

	/**
	 * Builds a complete TCP segment (header, options and optional payload) with its checksum set.
	 *
	 * <p>Keys read from {@code headerInfo}:
	 * <ul>
	 *   <li>"src_port", "dest_port" (Short, required);</li>
	 *   <li>"src_ip", "dest_ip" (byte[] IPv4 addresses, required for the pseudo-header checksum);</li>
	 *   <li>"seq_num", "ack_num" (Integer, optional: when absent, the value stored by a previous
	 *       call is reused);</li>
	 *   <li>"URG", "ACK", "PSH", "RST", "SYN", "FIN" (any non-null value sets the flag);</li>
	 *   <li>"data" (byte[] payload, optional).</li>
	 * </ul>
	 *
	 * @param headerInfo field values for the segment
	 * @return the serialised segment, or null if a port or an IP address is missing
	 */
	@Override
	public byte[] createHeader(HashMap<String, Object> headerInfo) {
		if (headerInfo.get("src_port") == null || headerInfo.get("dest_port") == null
				|| headerInfo.get("src_ip") == null || headerInfo.get("dest_ip") == null) {
			return null;
		}
		byte[] data = (byte[]) headerInfo.get("data");
		if (data == null) {
			data = new byte[0];
		}

		// Only overwrite the remembered numbers when the caller supplies new ones.
		if (headerInfo.get("seq_num") != null) {
			sequence_number = (int) headerInfo.get("seq_num");
		}
		if (headerInfo.get("ack_num") != null) {
			acknowledgement_number = (int) headerInfo.get("ack_num");
		}

		byte[] options = buildOptions();
		// The header (fixed part plus options) must be padded to a multiple of 4 bytes, because
		// the data offset field counts 32-bit words. The payload itself is not padded.
		int header_length = (HEADER_LENGTH + options.length + 3) / 4 * 4;
		int total_length = header_length + data.length;

		ByteBuffer buffer = ByteBuffer.wrap(new byte[total_length]);
		buffer.putShort((short) headerInfo.get("src_port"));
		buffer.putShort((short) headerInfo.get("dest_port"));
		buffer.putInt(sequence_number);
		buffer.putInt(acknowledgement_number);
		// Upper 4 bits: data offset (header length in 32-bit words); low 6 bits: control flags.
		int data_offset = (header_length / 4) & 0x0F;
		buffer.putShort((short) ((data_offset << 12) | buildControlBits(headerInfo)));
		buffer.putChar(WINDOW_SIZE);
		buffer.putShort((short) 0); // checksum placeholder, filled in below
		buffer.putShort((short) 0); // urgent pointer
		buffer.put(options);
		// Any bytes between the end of the options and header_length remain zero (padding).
		buffer.position(header_length);
		buffer.put(data);

		// The checksum field is still zero here, as the checksum computation requires.
		short check_sum = (short) compute_checksum(headerInfo, buffer);
		buffer.putShort(POSITION_FOR_CHECKSUM, check_sum);
		return buffer.array();
	}

	/** Returns the six control-flag bits selected by the presence of their keys in headerInfo. */
	private static int buildControlBits(HashMap<String, Object> headerInfo) {
		int bits = 0;
		if (headerInfo.get("URG") != null) {
			bits |= TCP_URG_BIT;
		}
		if (headerInfo.get("ACK") != null) {
			bits |= TCP_ACK_BIT;
		}
		if (headerInfo.get("PSH") != null) {
			bits |= TCP_PSH_BIT;
		}
		if (headerInfo.get("RST") != null) {
			bits |= TCP_RST_BIT;
		}
		if (headerInfo.get("SYN") != null) {
			bits |= TCP_SYN_BIT;
		}
		if (headerInfo.get("FIN") != null) {
			bits |= TCP_FIN_BIT;
		}
		return bits;
	}

	/** Builds the MSS option, the window-scale option and the end-of-option-list byte. */
	private static byte[] buildOptions() {
		ByteBuffer options = ByteBuffer.allocate(
				MAXIMUN_SEGMENT_SIZE_OPTION_LENGTH + WINDOW_SCALE_OPTION_LENGTH + 1);
		options.put(MAXIMUN_SEGMENT_OPTION_KIND);
		options.put(MAXIMUN_SEGMENT_SIZE_OPTION_LENGTH);
		options.putShort(MAXIMUM_SEGMENT_SIZE);
		options.put(WINDOW_SCALE_OPTION_KIND);
		options.put(WINDOW_SCALE_OPTION_LENGTH);
		options.put(WINDOW_SCALE_SHIFT_BYTES);
		options.put(END_OF_OPTION_LIST);
		return options.array();
	}

	/**
	 * Computes the TCP checksum over the IPv4 pseudo header followed by the whole segment.
	 *
	 * <p>The pseudo header holds:
	 * <ul>
	 *   <li>the source and destination IP;</li>
	 *   <li>a zero byte;</li>
	 *   <li>the protocol number (6);</li>
	 *   <li>the TCP segment length (header plus payload).</li>
	 * </ul>
	 *
	 * @param headerInfo must contain "src_ip" and "dest_ip" as 4-byte arrays
	 * @param tcp_buffer the serialised segment with a zero checksum field
	 * @return the value returned by Utility.checksum
	 */
	private long compute_checksum(HashMap<String, Object> headerInfo, ByteBuffer tcp_buffer) {
		byte[] segment = tcp_buffer.array();
		ByteBuffer total_buf = ByteBuffer.allocate(PSEUDO_HEADER_LENGTH + segment.length);
		total_buf.put((byte[]) headerInfo.get("src_ip"));
		total_buf.put((byte[]) headerInfo.get("dest_ip"));
		total_buf.put((byte) 0);
		total_buf.put(TCP_PROTOCOL_NUMBER);
		// The TCP length is part of the pseudo header; leaving it out yields a wrong checksum.
		total_buf.putShort((short) segment.length);
		total_buf.put(segment);
		byte[] total_buffer = total_buf.array();
		// NOTE(review): assumes Utility.checksum zero-pads odd-length input — confirm.
		return Utility.checksum(total_buffer, total_buffer.length);
	}

	/**
	 * Parses the fixed 20-byte TCP header at the start of {@code packet.header}.
	 *
	 * <p>Fills these keys:
	 * <ul>
	 *   <li>"src_port", "dest_port", "window", "urg_ptr" (Short);</li>
	 *   <li>"seq_num", "ack_num" (Integer);</li>
	 *   <li>one key per set flag, among "URG", "ACK", "PSH", "RST", "SYN" and "FIN" (value 1).</li>
	 * </ul>
	 * The checksum is skipped and options are not parsed.
	 */
	@Override
	public HashMap<String, Object> handlePacket(Packet packet) {
		ByteBuffer buffer = ByteBuffer.wrap(packet.header);
		HashMap<String, Object> headerInfo = new HashMap<String, Object>();
		headerInfo.put("src_port", buffer.getShort());
		headerInfo.put("dest_port", buffer.getShort());
		headerInfo.put("seq_num", buffer.getInt());
		headerInfo.put("ack_num", buffer.getInt());
		// The upper 4 bits hold the data offset; only the low flag bits are tested below.
		short control_bits = buffer.getShort();
		if ((control_bits & TCP_URG_BIT) != 0) {
			headerInfo.put("URG", 1);
		}
		if ((control_bits & TCP_ACK_BIT) != 0) {
			headerInfo.put("ACK", 1);
		}
		if ((control_bits & TCP_PSH_BIT) != 0) {
			headerInfo.put("PSH", 1);
		}
		if ((control_bits & TCP_RST_BIT) != 0) {
			headerInfo.put("RST", 1);
		}
		if ((control_bits & TCP_SYN_BIT) != 0) {
			headerInfo.put("SYN", 1);
		}
		if ((control_bits & TCP_FIN_BIT) != 0) {
			headerInfo.put("FIN", 1);
		}
		headerInfo.put("window", buffer.getShort());
		// Skip the checksum.
		buffer.getShort();
		headerInfo.put("urg_ptr", buffer.getShort());
		return headerInfo;
	}
}

上面的代码实现了TCP协议层的封包与解包:在函数createHeader中,我们按照上图所示的结构填写包头的各个字段;在函数handlePacket中,我们从收到的包头中解析出各字段的信息。

在负责中转的ProtocolManager层中,我们实现如下代码:

/**
 * Hands a received TCP segment to the application listening on its destination port.
 *
 * <p>The segment header is parsed by TCPProtocolLayer. If ApplicationManager knows an
 * application for the destination port, the parsed field map is passed to that application's
 * handleData; otherwise the segment is dropped silently.
 *
 * @param packet          the captured packet whose header is the TCP header
 * @param infoFromUpLayer not used by this method
 */
private void handleTCPPacket(Packet packet,  HashMap<String, Object> infoFromUpLayer) {
	HashMap<String, Object> parsed = new TCPProtocolLayer().handlePacket(packet);
	// Find the application that should receive this TCP segment, keyed by destination port.
	short targetPort = (short) parsed.get("dest_port");
	IApplication receiver = ApplicationManager.getInstance().getApplicationByPort(targetPort);
	if (receiver == null) {
		return;
	}
	receiver.handleData(parsed);
}

程序通过Jpcap收到TCP数据包后,会先交给上面实现的TCPProtocolLayer解析数据包内的各个字段,然后检查该数据包的目的端口在应用层是否有对应的接收对象;如果有,就将解析出的信息转交给该接收对象。接下来我们看应用层的相关实现:

/**
 * Client side of a TCP three-way handshake driven by hand-built raw packets.
 *
 * <p>beginThreeHandShakes() sends a SYN from the fixed local port 11940. When a segment with
 * both SYN and ACK set is delivered to handleData, the client replies with an ACK that
 * completes the handshake.
 */
public class TCPThreeHandShakes extends Application{
	private final byte[] dest_ip;
	private final short dest_port;
	/** Acknowledgement number placed in the next outgoing segment. */
	private int ack_num = 0;
	/** Sequence number placed in the next outgoing segment (the initial SYN uses 0). */
	private int seq_num = 0;

	/**
	 * @param server_ip   destination IPv4 address as raw bytes
	 * @param server_port destination TCP port
	 */
	public TCPThreeHandShakes(byte[] server_ip, short server_port) {
		this.dest_ip = server_ip;
		this.dest_port = server_port;
		// Use a fixed local port so the traffic is easy to filter in a packet capture.
		this.port = (short)11940;
	}

	/** Starts the handshake by sending a SYN segment. */
	public void beginThreeHandShakes() throws Exception {
		createAndSendPacket(null, "SYN");
	}

	/**
	 * Builds a TCP segment carrying {@code data} (may be null), wraps it in an IPv4 header and
	 * sends it.
	 *
	 * @param data  payload, or null for none
	 * @param flags comma-separated flag names, e.g. "SYN" or "SYN,ACK"
	 * @throws Exception if the TCP or IP header cannot be created
	 */
	private void createAndSendPacket(byte[] data, String flags) throws Exception {
		byte[] tcpHeader = createTCPHeader(data, flags);
		if (tcpHeader == null) {
			throw new Exception("tcp Header create fail");
		}
		byte[] ipHeader = createIP4Header(tcpHeader.length);
		if (ipHeader == null) {
			throw new Exception("ip Header create fail");
		}
		byte[] packet = new byte[tcpHeader.length + ipHeader.length];
		ByteBuffer packetBuffer = ByteBuffer.wrap(packet);
		packetBuffer.put(ipHeader);
		packetBuffer.put(tcpHeader);
		sendPacket(packet);
	}

	/** Passes the finished IP packet to ProtocolManager; failures are only printed. */
	private void sendPacket(byte[] packet) {
		try {
			// NOTE(review): hard-coded address handed to ProtocolManager.sendData; presumably the
			// LAN gateway used for next-hop resolution — confirm and make configurable.
			InetAddress ip = InetAddress.getByName("192.168.2.1");
			ProtocolManager.getInstance().sendData(packet, ip.getAddress());
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Asks the registered "tcp" protocol to serialise a segment using the current seq/ack numbers.
	 *
	 * @param data  payload, or null for none
	 * @param flags comma-separated flag names; surrounding whitespace is ignored
	 * @return the segment bytes, or null if no "tcp" protocol is registered or creation failed
	 */
	private byte[] createTCPHeader(byte[] data, String flags) {
		IProtocol tcpProto = ProtocolManager.getInstance().getProtocol("tcp");
		if (tcpProto == null) {
			return null;
		}
		HashMap<String, Object> headerInfo = new HashMap<String, Object>();
		byte[] src_ip = DataLinkLayer.getInstance().deviceIPAddress();
		headerInfo.put("src_ip", src_ip);
		headerInfo.put("dest_ip", this.dest_ip);
		headerInfo.put("src_port", (short)this.port);
		headerInfo.put("dest_port", this.dest_port);
		headerInfo.put("seq_num", seq_num);
		headerInfo.put("ack_num", ack_num);
		if (data != null) {
			headerInfo.put("data", data);
		}
		for (String flag : flags.split(",")) {
			String name = flag.trim();
			if (!name.isEmpty()) {
				headerInfo.put(name, 1);
			}
		}
		return tcpProto.createHeader(headerInfo);
	}

	/**
	 * Builds an IPv4 header for a TCP payload of {@code dataLength} bytes sent to dest_ip.
	 *
	 * @return the header bytes, or null if no "ip" protocol is registered or dataLength <= 0
	 */
	protected byte[] createIP4Header(int dataLength) {
		IProtocol ip4Proto = ProtocolManager.getInstance().getProtocol("ip");
		if (ip4Proto == null || dataLength <= 0) {
			return null;
		}
		// By default the IP header only needs the payload length, the upper-layer protocol
		// number and the destination IP address.
		HashMap<String, Object> headerInfo = new HashMap<String, Object>();
		headerInfo.put("data_length", dataLength);
		ByteBuffer destIP = ByteBuffer.wrap(this.dest_ip);
		headerInfo.put("destination_ip", destIP.getInt());
		byte protocol = TCPProtocolLayer.TCP_PROTOCOL_NUMBER;
		headerInfo.put("protocol", protocol);
		headerInfo.put("identification", (short)this.port);
		return ip4Proto.createHeader(headerInfo);
	}

	/**
	 * Reacts to a parsed TCP segment addressed to this application.
	 *
	 * <p>On SYN+ACK, sends the final ACK of the handshake. Other segments are only logged.
	 */
	@Override
	public void handleData(HashMap<String, Object> headerInfo) {
		short src_port = (short)headerInfo.get("src_port");
		System.out.println("receive TCP packet with port:" + src_port);
		boolean ack = headerInfo.get("ACK") != null;
		boolean syn = headerInfo.get("SYN") != null;
		if (ack) {
			System.out.println("it is a ACK packet");
		}
		if (syn) {
			System.out.println("it is a SYN packet");
		}
		if (ack && syn) {
			int peer_seq = (int)headerInfo.get("seq_num");
			int peer_ack = (int)headerInfo.get("ack_num");
			System.out.println("tcp handshake from othersize with seq_num" + peer_seq + " and ack_num: " + peer_ack);
			// Third step of the handshake (RFC 793 §3.4):
			// - our next sequence number is the peer's acknowledgement number (our ISN + 1);
			// - we acknowledge the peer's SYN with its sequence number + 1.
			// int arithmetic wraps mod 2^32, matching TCP sequence-number arithmetic.
			this.seq_num = peer_ack;
			this.ack_num = peer_seq + 1;
			try {
				createAndSendPacket(null, "ACK");
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
	}
}

应用层对象的主要目标是实现TCP连接的三次握手。它首先构造一个打开了SYN控制位的TCP数据包,并将其发送给目标服务器,然后等待对方回应。一旦本机收到对方回发的SYN+ACK数据包,相关信息就会转交给当前应用对象。应用对象从中解读出对方的序列号和确认号,再构造一个ACK包发送给对方,完成第三次握手。按照TCP协议的规定,这个ACK包应以对方的确认号作为自己的序列号,并以对方的序列号加一作为确认号。上面的程序运行后,通过Wireshark抓包可以看到如下结果:

代码实现TCP三次握手:程序实现-LMLPHP

由此可见,我们成功地完成了TCP连接的三次握手。上图中有一个数据包设置了RST标志位,表示重置连接。这个数据包其实不是我们的应用对象发送的。这很可能是因为我们绕过了操作系统的网络协议栈直接发送数据包:当对方的回应数据包到达时,操作系统的协议栈发现本机并不存在与之对应的接收对象(连接),于是自己构造了一个RST数据包发回给对方。

更详细的讲解和代码调试演示过程,请点击链接

更多技术信息,包括操作系统、编译器、面试算法、机器学习、人工智能,请关注我的公众号:
代码实现TCP三次握手:程序实现-LMLPHP

新书上架,请诸位朋友多多支持:代码实现TCP三次握手:程序实现-LMLPHP

08-29 04:00