113 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			113 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Go
		
	
	
	
package ip
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"net"
 | 
						|
)
 | 
						|
 | 
						|
// Fast lookup table for intersection of a single IP address within a collection of CIDR networks.
 | 
						|
//
 | 
						|
// Supports 4-byte (IPv4) and 16-byte (IPv6) networks.
 | 
						|
//
 | 
						|
// Provides O(1) best-case, O(log(n)) worst-case performance.
 | 
						|
// In practice netmasks included will generally only be of standard lengths:
 | 
						|
// - /8, /16, /24, and /32 for IPv4
 | 
						|
// - /64 and /128 for IPv6.
 | 
						|
// As a result, typical lookup times will lean closer to best-case rather than worst-case even when most of the internet
 | 
						|
// is included.
 | 
						|
type NetSet struct {
 | 
						|
	ip4NetMaps []ipNetMap
 | 
						|
	ip6NetMaps []ipNetMap
 | 
						|
}
 | 
						|
 | 
						|
// Create a new NetSet with all of the provided networks.
 | 
						|
func NewNetSet() *NetSet {
 | 
						|
	return &NetSet{
 | 
						|
		ip4NetMaps: make([]ipNetMap, 0),
 | 
						|
		ip6NetMaps: make([]ipNetMap, 0),
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Check if `ip` is in the set, true if within the set otherwise false.
 | 
						|
func (w *NetSet) Has(ip net.IP) bool {
 | 
						|
	netMaps := w.getNetMaps(ip)
 | 
						|
 | 
						|
	// Check all ipNetMaps for intersection with `ip`.
 | 
						|
	for _, netMap := range *netMaps {
 | 
						|
		if netMap.has(ip) {
 | 
						|
			return true
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
// Add an CIDR network to the set.
 | 
						|
func (w *NetSet) AddIPNet(ipNet net.IPNet) {
 | 
						|
	netMaps := w.getNetMaps(ipNet.IP)
 | 
						|
 | 
						|
	// Determine the size / number of ones in the CIDR network mask.
 | 
						|
	ones, _ := ipNet.Mask.Size()
 | 
						|
 | 
						|
	var netMap *ipNetMap
 | 
						|
 | 
						|
	// Search for the ipNetMap containing networks with the same number of ones.
 | 
						|
	for i := 0; len(*netMaps) > i; i++ {
 | 
						|
		if netMapOnes, _ := (*netMaps)[i].mask.Size(); netMapOnes == ones {
 | 
						|
			netMap = &(*netMaps)[i]
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Create a new ipNetMap if none with this number of ones have been created yet.
 | 
						|
	if netMap == nil {
 | 
						|
		netMap = &ipNetMap{
 | 
						|
			mask: ipNet.Mask,
 | 
						|
			ips:  make(map[string]bool),
 | 
						|
		}
 | 
						|
		*netMaps = append(*netMaps, *netMap)
 | 
						|
		// Recurse once now that there exists an netMap.
 | 
						|
		w.AddIPNet(ipNet)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// Add the IP to the ipNetMap.
 | 
						|
	netMap.ips[ipNet.IP.String()] = true
 | 
						|
}
 | 
						|
 | 
						|
// Get the appropriate array of networks for the given IP version.
 | 
						|
func (w *NetSet) getNetMaps(ip net.IP) (netMaps *[]ipNetMap) {
 | 
						|
	switch {
 | 
						|
	case ip.To4() != nil:
 | 
						|
		netMaps = &w.ip4NetMaps
 | 
						|
	case ip.To16() != nil:
 | 
						|
		netMaps = &w.ip6NetMaps
 | 
						|
	default:
 | 
						|
		panic(fmt.Sprintf("IP (%s) is neither 4-byte nor 16-byte?", ip.String()))
 | 
						|
	}
 | 
						|
 | 
						|
	return netMaps
 | 
						|
}
 | 
						|
 | 
						|
// ipNetMap is a hash set of CIDR network addresses that all share one mask.
// has probes it with the String() form of a masked address, so keys are
// expected to be network addresses in that form.
type ipNetMap struct {
	mask net.IPMask      // mask shared by every network in the set
	ips  map[string]bool // network addresses present in the set
}

// has reports whether ip lies inside any network held in m. It masks ip with
// the shared mask and checks whether the resulting network address is a key.
// It panics when the mask cannot be applied to ip (incompatible lengths,
// e.g. an IPv6 address against a 4-byte mask).
func (m ipNetMap) has(ip net.IP) bool {
	network := ip.Mask(m.mask)
	if network != nil {
		_, found := m.ips[network.String()]
		return found
	}
	panic(fmt.Sprintf(
		"Mismatch in net.IPMask and net.IP protocol version, cannot apply mask %s to %s",
		m.mask.String(), ip.String()))
}
