From: Hans-Christoph Steiner
Date: Mon, 17 Mar 2025 13:39:04 +0000 (+0100)
Subject: New upstream version 0.9.18
X-Git-Tag: upstream/0.9.18^0
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fupstream;p=location%2Fdebian%2Flibloc.git

New upstream version 0.9.18
---

diff --git a/.gitignore b/.gitignore
index 41e3075..4823554 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,9 @@
+*~
 *.log
 *.mo
 *.o
 *.tar.xz
+*.trs
 .deps/
 .libs/
 Makefile
diff --git a/Jenkinsfile b/Jenkinsfile
new file mode 100644
index 0000000..f6a07ab
--- /dev/null
+++ b/Jenkinsfile
@@ -0,0 +1,523 @@
+pipeline {
+	agent none
+
+	stages {
+		/*
+			Run the build and test suite on various distributions...
+		*/
+		stage("Run Tests on Multiple Distributions") {
+			matrix {
+				axes {
+					axis {
+						name "DISTRO"
+						values \
+							"archlinux:base-devel", \
+							"debian:trixie", \
+							"debian:bookworm", \
+							"fedora:41", \
+							"fedora:42", \
+							"ubuntu:24.10", \
+							"ubuntu:25.04"
+					}
+
+					axis {
+						name "COMPILER"
+						values "gcc", "clang"
+					}
+				}
+
+				agent {
+					docker {
+						image "${DISTRO}"
+
+						// Run as root inside the containers to install dependencies
+						args "-u root"
+
+						customWorkspace "${JOB_NAME}/${BUILD_ID}/${env.DISTRO.replace(":", "-")}/${env.COMPILER}"
+					}
+				}
+
+				stages {
+					stage("Install Dependencies") {
+						steps {
+							script {
+								// Arch Linux
+								if (env.DISTRO.contains("archlinux")) {
+									installBuildDepsArchLinux(env.DISTRO, env.COMPILER)
+
+								// Fedora, etc.
+								} else if (env.DISTRO.contains("fedora")) {
+									installBuildDepsRedHat(env.DISTRO, env.COMPILER)
+
+								// Debian & Ubuntu
+								} else if (env.DISTRO.contains("debian") || env.DISTRO.contains("ubuntu")) {
+									installBuildDepsDebian(env.DISTRO, env.COMPILER, "amd64")
+								}
+							}
+						}
+					}
+
+					stage("Configure") {
+						steps {
+							// Run autogen
+							sh "./autogen.sh"
+
+							// Run ./configure...
+							sh """
+								CC=${env.COMPILER} \
+								./configure \
+									--prefix=/usr \
+									--enable-debug \
+									--enable-lua \
+									--enable-perl
+							"""
+						}
+
+						post {
+							failure {
+								archiveArtifacts artifacts: "config.log",
+									allowEmptyArchive: true
+
+								echo "config.log has been archived"
+							}
+						}
+					}
+
+					stage("Build") {
+						steps {
+							sh "make"
+						}
+					}
+
+					stage("Check") {
+						steps {
+							script {
+								sh "make check"
+							}
+						}
+
+						post {
+							always {
+								// Copy test logs into a special directory
+								sh """
+									mkdir -pv tests/${DISTRO}/${COMPILER}
+									find src tests -name "*.log" | xargs --no-run-if-empty \
+										cp --verbose --parents --target-directory=tests/${DISTRO}/${COMPILER}/
+								"""
+
+								// Archive the logs
+								archiveArtifacts artifacts: "tests/${DISTRO}/${COMPILER}/**/*"
+
+								echo "The test logs have been archived"
+							}
+						}
+					}
+				}
+
+				// Cleanup the workspace afterwards
+				post {
+					always {
+						cleanWs()
+					}
+				}
+			}
+		}
+
+		stage("Coverage Tests") {
+			parallel {
+				/*
+					Run through Clang's Static Analyzer...
+ */ + stage("Clang Static Analyzer") { + agent { + docker { + image "debian:trixie" + + // Run as root inside the containers to install dependencies + args "-u root" + + customWorkspace "${JOB_NAME}/${BUILD_ID}/clang-static-analyzer" + } + } + + stages { + stage("Install Dependencies") { + steps { + script { + installBuildDepsDebian("trixie", "clang", "amd64") + + // Install Clang Tools + sh "apt-get install -y clang-tools" + } + } + } + + stage("Configure") { + steps { + sh "./autogen.sh" + sh """ + scan-build \ + ./configure \ + --prefix=/usr \ + --enable-debug \ + --enable-lua \ + --enable-perl + """ + } + } + + stage("Build") { + steps { + sh "scan-build -o scan-build-output make -j\$(nproc)" + } + } + + stage("Publish Report") { + steps { + archiveArtifacts artifacts: "scan-build-output/**/*" + } + } + } + + // Cleanup the workspace afterwards + post { + always { + cleanWs() + } + } + } + } + } + + stage("Debian Packages") { + // Only build packages when we are in the master branch + // when { + // expression { + // env.GIT_BRANCH == "origin/master" + // } + // } + + stages { + stage("Build Debian Packages") { + matrix { + axes { + axis { + name "IMAGE" + values "debian:trixie", "debian:bookworm" + } + + axis { + name "ARCH" + values "amd64", "arm64", "armel", "armhf", "i386", "ppc64el" + } + } + + agent { + docker { + image "${IMAGE}" + + // Run as root inside the containers to install dependencies + args "-u root" + + customWorkspace "${JOB_NAME}/${BUILD_ID}/${IMAGE.replace(":", "-")}/${ARCH}" + } + } + + stages { + stage("Setup Build Environment") { + steps { + // Add the architecture + sh "dpkg --add-architecture ${env.ARCH}" + sh "apt-get update" + + // Install required packages + sh """ + apt-get install -y \ + apt-utils \ + build-essential \ + crossbuild-essential-${env.ARCH} \ + devscripts \ + qemu-user-static + """ + } + } + + stage("Install Build Dependencies") { + steps { + // Install all build dependencies + sh "apt-get build-dep -y -a${env.ARCH} ." + } + } + + stage("Tag") { + steps { + sh "dch -m \"Jenkins Build ${BUILD_ID}\" -l .build-${BUILD_ID}." 
+ } + } + + stage("Build") { + environment { + VERSION = "" + } + + steps { + // Create the source tarball from the checked out source + sh """ + version="\$(dpkg-parsechangelog --show-field Version | sed "s/-[^-]*\$//")"; + tar \ + --create \ + --verbose \ + --xz \ + --file="../libloc_\$version.orig.tar.xz" \ + --transform="s|^|libloc-\$version/|" \ + * + """ + + // Build the packages + sh """ + dpkg-buildpackage \ + --host-arch ${env.ARCH} \ + --build=full + """ + } + } + + stage("Create Repository") { + environment { + DISTRO = "${IMAGE.replace("debian:", "")}" + } + + steps { + // Create a repository and generate Packages + sh "mkdir -pv \ + packages/debian/dists/${DISTRO}/main/binary-${ARCH} \ + packages/debian/dists/${DISTRO}/main/source \ + packages/debian/pool/${DISTRO}/main/${ARCH}" + + // Copy all packages + sh "cp -v ../*.deb packages/debian/pool/${DISTRO}/main/${ARCH}" + + // Generate Packages + sh "cd packages/debian && apt-ftparchive packages pool/${DISTRO}/main/${ARCH} \ + > dists/${DISTRO}/main/binary-${ARCH}/Packages" + + // Compress Packages + sh "xz -v9 < packages/debian/dists/${DISTRO}/main/binary-${ARCH}/Packages \ + > packages/debian/dists/${DISTRO}/main/binary-${ARCH}/Packages.xz" + + // Generate Sources + sh "cd packages/debian && apt-ftparchive sources pool/${DISTRO}/main/${ARCH} \ + > dists/${DISTRO}/main/source/Sources" + + // Compress Sources + sh "xz -v9 < packages/debian/dists/${DISTRO}/main/source/Sources \ + > packages/debian/dists/${DISTRO}/main/source/Sources.xz" + + // Generate Contents + sh "cd packages/debian && apt-ftparchive contents pool/${DISTRO}/main/${ARCH} \ + > dists/${DISTRO}/main/Contents-${ARCH}" + + // Compress Contents + sh "xz -v9 < packages/debian/dists/${DISTRO}/main/Contents-${ARCH} \ + > packages/debian/dists/${DISTRO}/main/Contents-${ARCH}.xz" + + // Stash the packages + stash includes: "packages/debian/**/*", name: "${DISTRO}-${ARCH}" + } + } + } + + // Cleanup the workspace afterwards + post { + always { + cleanWs() + } + } + } + } + + stage("Master Debian Repository") { + agent any + + // We don't need to check out the source for this stage + options { + skipDefaultCheckout() + } + + environment { + GNUPGHOME = "${WORKSPACE}/.gnupg" + KRB5CCNAME = "${WORKSPACE}/.krb5cc" + + // Our signing key + GPG_KEY_ID = "E4D20FA6FAA108D54ABDFC6541836ADF9D5E2AD9" + } + + steps { + // Cleanup the workspace + cleanWs() + + // Create the GPG stash directory + sh """ + mkdir -p $GNUPGHOME + chmod 700 $GNUPGHOME + """ + + // Import the GPG key + withCredentials([file(credentialsId: "${env.GPG_KEY_ID}", variable: "GPG_KEY_FILE")]) { + // Jenkins prefers to have single quotes here so that $GPG_KEY_FILE won't be expanded + sh 'gpg --import --batch < $GPG_KEY_FILE' + } + + // Unstash all stashed packages from the matrix build + script { + for (distro in ["trixie", "bookworm"]) { + for (arch in ["amd64", "arm64"]) { + unstash "${distro}-${arch}" + } + + // Create the Release file + sh """ + ( + echo "Origin: Pakfire Repository" + echo "Label: Pakfire Repository" + echo "Suite: stable" + echo "Codename: $distro" + echo "Version: 1.0" + echo "Architectures: amd64 arm64" + echo "Components: main" + echo "Description: Pakfire Jenkins Repository" + + # Do the rest automatically + cd packages/debian && apt-ftparchive release dists/$distro + ) >> packages/debian/dists/$distro/Release + """ + + // Create InRelease + sh """ + gpg --batch \ + --clearsign \ + --local-user ${env.GPG_KEY_ID} \ + --output packages/debian/dists/$distro/InRelease \ + 
packages/debian/dists/$distro/Release
+								"""
+
+								// Create Release.gpg
+								sh """
+									gpg --batch \
+										--armor \
+										--detach-sign \
+										--local-user ${env.GPG_KEY_ID} \
+										--output packages/debian/dists/$distro/Release.gpg \
+										packages/debian/dists/$distro/Release
+								"""
+							}
+						}
+
+						// Export the public key
+						sh "gpg --batch --export --armor ${env.GPG_KEY_ID} \
+							> packages/debian/${env.GPG_KEY_ID}.asc"
+
+						// Remove the GPG key material as soon as possible
+						sh "rm -rf $GNUPGHOME"
+
+						// Upload everything again
+						archiveArtifacts artifacts: "packages/debian/**/*"
+
+						// Fetch a Kerberos ticket
+						withCredentials([file(credentialsId: "jenkins.keytab", variable: "KEYTAB")]) {
+							sh 'kinit -kV -t $KEYTAB jenkins@IPFIRE.ORG'
+						}
+
+						// Publish files
+						sh """
+							rsync \
+								--verbose \
+								--recursive \
+								--delete \
+								--delete-excluded \
+								--delay-updates \
+								packages/debian/ \
+								pakfire@fs01.haj.ipfire.org:/pub/mirror/packages/debian/libloc
+						"""
+
+						// Destroy the Kerberos credentials
+						sh "kdestroy"
+					}
+				}
+			}
+		}
+	}
+}
+
+// Installs everything we need on RHEL/Fedora/etc.
+def installBuildDepsRedHat(distro, compiler) {
+	// Install basic development tools
+	sh "dnf group install -y development-tools"
+
+	// Install our own build and runtime dependencies
+	sh """
+		dnf install -y \
+			asciidoc \
+			autoconf \
+			automake \
+			intltool \
+			libtool \
+			pkg-config \
+			${compiler} \
+			\
+			lua-devel \
+			lua-lunit \
+			openssl-devel \
+			perl-devel \
+			"perl(Test::More)" \
+			python3-devel \
+			systemd-devel
+	"""
+}
+
+// Installs everything we need on Arch Linux
+def installBuildDepsArchLinux(distro, compiler) {
+	sh "pacman -Syu --noconfirm"
+	sh """
+		pacman -Sy \
+			--needed \
+			--noconfirm \
+			asciidoc \
+			autoconf \
+			automake \
+			intltool \
+			libtool \
+			pkg-config \
+			${compiler} \
+			\
+			lua \
+			openssl \
+			perl \
+			python3 \
+			systemd
+	"""
+}
+
+// Installs everything we need on Debian
+def installBuildDepsDebian(distro, compiler, arch) {
+	sh "apt-get update"
+	sh """
+		apt-get install -y \
+			--no-install-recommends \
+			asciidoc \
+			autoconf \
+			automake \
+			build-essential \
+			intltool \
+			libtool \
+			pkg-config \
+			${compiler} \
+			\
+			liblua5.4-dev \
+			libperl-dev \
+			libpython3-dev \
+			libssl-dev \
+			libsystemd-dev \
+			lua-unit
+	"""
+}
diff --git a/Makefile.am b/Makefile.am
index 89fc8b5..e5b4fc5 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -3,6 +3,7 @@ CLEANFILES =
 INSTALL_DIRS =
 ACLOCAL_AMFLAGS = -I m4 ${ACLOCAL_FLAGS}
 AM_MAKEFLAGS = --no-print-directory
+check_SCRIPTS =
 SUBDIRS = . 
po BINDINGS = @@ -20,11 +21,13 @@ AM_CPPFLAGS = \ -DSYSCONFDIR=\""$(sysconfdir)"\" \ -I${top_srcdir}/src -AM_CFLAGS = ${my_CFLAGS} \ +AM_CFLAGS = \ + $(OUR_CFLAGS) \ -ffunction-sections \ -fdata-sections -AM_LDFLAGS = ${my_LDFLAGS} +AM_LDFLAGS = \ + $(OUR_LDFLAGS) # leaving a space here to work around automake's conditionals ifeq ($(OS),Darwin) @@ -44,6 +47,7 @@ DISTCHECK_CONFIGURE_FLAGS = \ SED_PROCESS = \ $(AM_V_GEN)$(MKDIR_P) $(dir $@) && $(SED) \ + -e 's,@LUA_VERSION\@,$(LUA_VERSION),g' \ -e 's,@VERSION\@,$(VERSION),g' \ -e 's,@prefix\@,$(prefix),g' \ -e 's,@exec_prefix\@,$(exec_prefix),g' \ @@ -56,6 +60,7 @@ SED_PROCESS = \ cron_dailydir = $(sysconfdir)/cron.daily databasedir = $(localstatedir)/lib/location pkgconfigdir = $(libdir)/pkgconfig +systemdsystemunitdir = $(prefix)/lib/systemd/system # Overwrite Python path pkgpythondir = $(pythondir)/location @@ -87,6 +92,7 @@ po/POTFILES.in: Makefile sed -e "s@$(abs_srcdir)/@@g" | LC_ALL=C sort > $@ EXTRA_DIST += \ + README.md \ examples/private-key.pem \ examples/public-key.pem \ examples/python/create-database.py \ @@ -104,6 +110,7 @@ pkginclude_HEADERS = \ src/libloc/format.h \ src/libloc/network.h \ src/libloc/network-list.h \ + src/libloc/network-tree.h \ src/libloc/private.h \ src/libloc/stringpool.h \ src/libloc/resolv.h \ @@ -122,6 +129,7 @@ src_libloc_la_SOURCES = \ src/database.c \ src/network.c \ src/network-list.c \ + src/network-tree.c \ src/resolv.c \ src/stringpool.c \ src/writer.c @@ -131,16 +139,16 @@ EXTRA_DIST += src/libloc.sym src_libloc_la_CFLAGS = \ $(AM_CFLAGS) \ -DLIBLOC_PRIVATE \ - -fvisibility=hidden + -fvisibility=hidden \ + $(OPENSSL_CFLAGS) src_libloc_la_LDFLAGS = \ $(AM_LDFLAGS) \ - -version-info $(LIBLOC_CURRENT):$(LIBLOC_REVISION):$(LIBLOC_AGE) + -version-info $(LIBLOC_CURRENT):$(LIBLOC_REVISION):$(LIBLOC_AGE) \ + $(OPENSSL_LDFLAGS) if HAVE_LD_VERSION_SCRIPT src_libloc_la_LDFLAGS += -Wl,--version-script=$(top_srcdir)/src/libloc.sym -else -src_libloc_la_LDFLAGS += -export-symbols $(top_srcdir)/src/libloc.sym endif src_libloc_la_LIBADD = \ @@ -191,7 +199,6 @@ dist_pkgpython_PYTHON = \ src/python/location/downloader.py \ src/python/location/export.py \ src/python/location/i18n.py \ - src/python/location/importer.py \ src/python/location/logger.py pyexec_LTLIBRARIES = \ @@ -213,7 +220,10 @@ src_python__location_la_SOURCES = \ src_python__location_la_CFLAGS = \ $(AM_CFLAGS) \ - $(PYTHON_CFLAGS) + $(PYTHON_CFLAGS) \ + -Wno-cast-function-type \ + -Wno-redundant-decls \ + -Wno-strict-aliasing src_python__location_la_LDFLAGS = \ $(AM_LDFLAGS) \ @@ -225,6 +235,70 @@ src_python__location_la_LIBADD = \ src/libloc.la \ $(PYTHON_LIBS) +# ------------------------------------------------------------------------------ + +if ENABLE_LUA +lua_LTLIBRARIES = \ + src/lua/location.la + +luadir = $(LUA_INSTALL_CMOD) + +src_lua_location_la_SOURCES = \ + src/lua/as.c \ + src/lua/as.h \ + src/lua/compat.h \ + src/lua/country.c \ + src/lua/country.h \ + src/lua/database.c \ + src/lua/database.h \ + src/lua/location.c \ + src/lua/location.h \ + src/lua/network.c \ + src/lua/network.h + +src_lua_location_la_CFLAGS = \ + $(AM_CFLAGS) \ + $(LUA_CFLAGS) + +src_lua_location_la_LDFLAGS = \ + $(AM_LDFLAGS) \ + $(LUA_LDFLAGS) \ + -shared \ + -module \ + -avoid-version + +src_lua_location_la_LIBADD = \ + src/libloc.la \ + $(LUA_LIBS) +endif + +EXTRA_DIST += \ + src/lua/as.c \ + src/lua/as.h \ + src/lua/country.c \ + src/lua/country.h \ + src/lua/database.c \ + src/lua/database.h \ + src/lua/location.c \ + src/lua/location.h \ + src/lua/network.c 
\
+	src/lua/network.h
+
+LUA_TESTS = \
+	tests/lua/main.lua
+
+EXTRA_DIST += \
+	tests/lua/main.lua.in
+
+CLEANFILES += \
+	tests/lua/main.lua
+
+tests/lua/main.lua: tests/lua/main.lua.in Makefile
+	$(SED_PROCESS)
+	chmod o+x $@
+
+# ------------------------------------------------------------------------------
+
 # Compile & install bindings
 all-local: $(foreach binding,$(BINDINGS),build-$(binding))
 check-local: $(foreach binding,$(BINDINGS),check-$(binding))
@@ -344,11 +418,16 @@ TESTS_LDADD = \
 	src/libloc-internal.la
 
 TESTS_ENVIRONMENT = \
+	LD_LIBRARY_PATH="$(abs_builddir)/src/.libs" \
+	LUA_CPATH="$(abs_builddir)/src/lua/.libs/?.so;;" \
 	PYTHONPATH=$(abs_srcdir)/src/python:$(abs_builddir)/src/python/.libs \
-	TEST_DATA_DIR="$(abs_top_srcdir)/data"
+	TEST_DATA_DIR="$(abs_top_srcdir)/data" \
+	TEST_DATABASE="$(abs_top_srcdir)/data/database.db" \
+	TEST_SIGNING_KEY="$(abs_top_srcdir)/data/signing-key.pem"
 
 TESTS = \
 	$(check_PROGRAMS) \
+	$(check_SCRIPTS) \
 	$(dist_check_SCRIPTS)
 
 CLEANFILES += \
@@ -360,9 +439,16 @@ testdata.db: examples/python/create-database.py
 	$(PYTHON) $< $@
 
 dist_check_SCRIPTS = \
+	tests/python/country.py \
+	tests/python/networks-dedup.py \
 	tests/python/test-database.py \
 	tests/python/test-export.py
 
+if ENABLE_LUA_TESTS
+check_SCRIPTS += \
+	$(LUA_TESTS)
+endif
+
 check_PROGRAMS = \
 	src/test-libloc \
 	src/test-stringpool \
@@ -539,6 +625,9 @@ man/%.html: man/%.txt man/asciidoc.conf
 upload-man: $(MANPAGES_HTML)
 	rsync -avHz --delete --progress $(MANPAGES_HTML) ms@fs01.haj.ipfire.org:/pub/man-pages/$(PACKAGE_NAME)/
 
+EXTRA_DIST += \
+	tools/copy.py
+
 EXTRA_DIST += \
 	debian/build.sh \
 	debian/changelog \
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..36c8144
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+# **_`libloc`_** - IP Address Location
+
+[Home](https://www.ipfire.org/location)
+
+`libloc` is a library for fast and efficient IP address location.
+
+It offers:
+
+- **The Fastest Lookups**: O(1) lookup time for IP addresses using a binary tree structure.
+- **Low Memory Footprint**: The database is packed in a very efficient format.
+- **Security**: Integrated signature verification for data integrity.
+- **Maintainability**: Automatic updates.
+- **Standalone**: No external dependencies, easy to integrate.
+
+`libloc` is ideal for:
+
+- Firewalls
+- Intrusion Prevention/Detection Systems (IPS/IDS)
+- Web Applications
+- Network Management Tools
+
+The publicly available, daily updated database stores information about:
+
+- The entire IPv6 and IPv4 Internet
+- Autonomous System information, including names
+- Country codes, names and continent codes
+
+## Command Line
+
+`libloc` comes with a command-line tool that makes it easy to test the library or to
+integrate it into your shell scripts. location(8) offers commands to retrieve the
+country or Autonomous System of an IP address and can generate lists of networks to
+be imported into other software.
+
+location(8) is versatile and very easy to use.
+
+## Language Bindings
+
+`libloc` itself is written in C. Bindings are available for the following languages:
+
+- Python 3
+- Lua
+- Perl
+
+`libloc` comes with native Python bindings, which are used by its main command-line tool
+`location`. They are the most advanced bindings, as they support reading from the database
+as well as writing to it.
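
The README above stops short of a usage example, so here is a minimal sketch of a
country lookup in C. This is not part of the upstream patch: it assumes the public
headers as installed into <libloc/...> (see the pkginclude_HEADERS list in the
Makefile.am hunk above), the database path shipped by debian/location.install, and
the libloc.pc file the build installs; the function names are taken from libloc's
public API.

	/* lookup.c - build with: cc lookup.c $(pkg-config --cflags --libs libloc) */
	#include <stdio.h>

	#include <libloc/libloc.h>
	#include <libloc/database.h>
	#include <libloc/network.h>

	int main(int argc, char** argv) {
		struct loc_ctx* ctx = NULL;
		struct loc_database* db = NULL;
		struct loc_network* network = NULL;
		int r = 1;

		if (argc < 2) {
			fprintf(stderr, "Usage: %s ADDRESS\n", argv[0]);
			return 2;
		}

		// Create a new library context
		if (loc_new(&ctx))
			return 1;

		// Open the database. libloc dup()s the descriptor (see
		// loc_database_clone_handle below), so the handle can be
		// closed again right after loc_database_new().
		FILE* f = fopen("/var/lib/location/database.db", "r");
		if (!f)
			goto ERROR;

		r = loc_database_new(ctx, &db, f);
		fclose(f);
		if (r)
			goto ERROR;

		// Find the network that contains the given address. As of this
		// release the lookup returns zero even if nothing was found, so
		// the result pointer has to be checked, too.
		r = loc_database_lookup_from_string(db, argv[1], &network);
		if (r || !network) {
			fprintf(stderr, "No match for %s\n", argv[1]);
			goto ERROR;
		}

		printf("%s is in %s\n", argv[1], loc_network_get_country_code(network));
		loc_network_unref(network);
		r = 0;

	ERROR:
		if (db)
			loc_database_unref(db);
		loc_unref(ctx);

		return r;
	}
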
diff --git a/configure.ac b/configure.ac index faf896b..8fb8e12 100644 --- a/configure.ac +++ b/configure.ac @@ -1,6 +1,6 @@ AC_PREREQ(2.60) AC_INIT([libloc], - [0.9.17], + [0.9.18], [location@lists.ipfire.org], [libloc], [https://location.ipfire.org/]) @@ -125,33 +125,65 @@ AC_CHECK_FUNCS([ \ qsort \ ]) -my_CFLAGS="\ --Wall \ --Wchar-subscripts \ --Wformat-security \ --Wmissing-declarations \ --Wmissing-prototypes \ --Wnested-externs \ --Wpointer-arith \ --Wshadow \ --Wsign-compare \ --Wstrict-prototypes \ --Wtype-limits \ -" -AC_SUBST([my_CFLAGS]) -AC_SUBST([my_LDFLAGS]) +CC_CHECK_FLAGS_APPEND([with_cflags], [CFLAGS], [\ + -std=gnu17 \ + -Wall \ + -Wextra \ + -Warray-bounds=2 \ + -Wdate-time \ + -Wendif-labels \ + -Werror=format=2 \ + -Werror=format-signedness \ + -Werror=implicit-function-declaration \ + -Werror=implicit-int \ + -Werror=incompatible-pointer-types \ + -Werror=int-conversion \ + -Werror=missing-declarations \ + -Werror=missing-prototypes \ + -Werror=overflow \ + -Werror=override-init \ + -Werror=return-type \ + -Werror=shift-count-overflow \ + -Werror=shift-overflow=2 \ + -Werror=strict-flex-arrays \ + -Werror=undef \ + -Wfloat-equal \ + -Wimplicit-fallthrough=5 \ + -Winit-self \ + -Wlogical-op \ + -Wmissing-include-dirs \ + -Wmissing-noreturn \ + -Wnested-externs \ + -Wold-style-definition \ + -Wpointer-arith \ + -Wredundant-decls \ + -Wshadow \ + -Wstrict-aliasing=2 \ + -Wstrict-prototypes \ + -Wsuggest-attribute=noreturn \ + -Wunused-function \ + -Wwrite-strings \ + -Wzero-length-bounds \ + -Wno-unused-parameter \ + -Wno-missing-field-initializers \ + -fdiagnostics-show-option \ + -fno-common \ +]) + +# Enable -fno-semantic-interposition (if available) +CC_CHECK_FLAGS_APPEND([with_cflags], [CFLAGS], [-fno-semantic-interposition]) +CC_CHECK_FLAGS_APPEND([with_ldflags], [LDFLAGS], [-fno-semantic-interposition]) # Enable -fanalyzer if requested AC_ARG_ENABLE([analyzer], AS_HELP_STRING([--enable-analyzer], [enable static analyzer (-fanalyzer) @<:@default=disabled@:>@]), [], [enable_analyzer=no]) AS_IF([test "x$enable_analyzer" = "xyes"], - CC_CHECK_FLAGS_APPEND([my_CFLAGS], [CFLAGS], [-fanalyzer]) + CC_CHECK_FLAGS_APPEND([with_cflags], [CFLAGS], [-fanalyzer]) ) -# Enable -fno-semantic-interposition (if available) -CC_CHECK_FLAGS_APPEND([my_CFLAGS], [CFLAGS], [-fno-semantic-interposition]) -CC_CHECK_FLAGS_APPEND([my_LDFLAGS], [LDFLAGS], [-fno-semantic-interposition]) +AC_SUBST([OUR_CFLAGS], $with_cflags) +AC_SUBST([OUR_LDFLAGS], $with_ldflags) # ------------------------------------------------------------------------------ @@ -178,34 +210,57 @@ AS_IF([test "x$with_systemd" != "xno"], [have_systemd=no] ) -AS_IF([test "x$have_systemd" = "xyes"], - [AC_MSG_CHECKING([for systemd system unit directory]) - AC_ARG_WITH([systemdsystemunitdir], - AS_HELP_STRING([--with-systemdsystemunitdir=DIR], [Directory for systemd service files]), - [], [with_systemdsystemunitdir=$($PKG_CONFIG --variable=systemdsystemunitdir systemd)] - ) - - AC_SUBST([systemdsystemunitdir], [$with_systemdsystemunitdir]) - - if test -n "$systemdsystemunitdir" -a "x$systemdsystemunitdir" != xno; then - AC_MSG_RESULT([$systemdsystemunitdir]) - else - AC_MSG_ERROR([not found (try --with-systemdsystemunitdir)]) - fi - ], - [AS_IF([test "x$with_systemd" = "xyes"], - [AC_MSG_ERROR([Systemd support is enabled but no systemd has been found.]) - ]) -]) - AM_CONDITIONAL(HAVE_SYSTEMD, [test "x$have_systemd" = "xyes"]) # ------------------------------------------------------------------------------ +AC_PATH_PROG(PKG_CONFIG, 
pkg-config, no) + # Python AM_PATH_PYTHON([3.4]) PKG_CHECK_MODULES([PYTHON], [python-${PYTHON_VERSION}]) +# Lua +AC_ARG_ENABLE(lua, + AS_HELP_STRING([--disable-lua], [do not build the Lua modules]), [], [enable_lua=yes]) + +AS_IF( + [test "x$enable_lua" = "xyes"], [ + for lua in lua lua5.4 lua5.3 lua5.2 lua5.1; do + PKG_CHECK_MODULES([LUA], [${lua}], [break], [true]) + done + + LUA_VERSION=$($PKG_CONFIG --variable=major_version ${lua}) + if test -z "${LUA_VERSION}"; then + LUA_VERSION=$($PKG_CONFIG --variable=V ${lua}) + fi + + # Fail if we could not find anything to link against + if test "x${LUA_VERSION}" = "x"; then + AC_MSG_ERROR([Could not find Lua]) + fi + + AC_SUBST(LUA_VERSION) + + LUA_INSTALL_LMOD=$($PKG_CONFIG --define-variable=prefix=${prefix} --variable=INSTALL_LMOD ${lua}) + if test -z "${LUA_INSTALL_LMOD}"; then + LUA_INSTALL_LMOD="${datadir}/lua/${LUA_VERSION}" + fi + AC_SUBST(LUA_INSTALL_LMOD) + + LUA_INSTALL_CMOD=$($PKG_CONFIG --define-variable=prefix=${prefix} --variable=INSTALL_CMOD ${lua}) + if test -z "${LUA_INSTALL_CMOD}"; then + LUA_INSTALL_CMOD="${libdir}/lua/${LUA_VERSION}" + fi + AC_SUBST(LUA_INSTALL_CMOD) + + AX_PROG_LUA_MODULES([luaunit], [enable_lua_tests=yes], [AC_MSG_WARN([luaunit is missing, won't run tests])]) + ], +) + +AM_CONDITIONAL(ENABLE_LUA, [test "x$enable_lua" = "xyes"]) +AM_CONDITIONAL(ENABLE_LUA_TESTS, [test "x$enable_lua_tests" = "xyes"]) + # Perl AC_PATH_PROG(PERL, perl, no) AC_SUBST(PERL) @@ -236,10 +291,7 @@ RESOLV_LIBS="${LIBS}" AC_SUBST(RESOLV_LIBS) dnl Checking for OpenSSL -LIBS= -AC_CHECK_LIB(crypto, EVP_EncryptInit,, AC_MSG_ERROR([libcrypto has not been found])) -OPENSSL_LIBS="${LIBS}" -AC_SUBST(OPENSSL_LIBS) +PKG_CHECK_MODULES([OPENSSL], [openssl]) AC_CONFIG_HEADERS(config.h) AC_CONFIG_FILES([ @@ -267,6 +319,9 @@ AC_MSG_RESULT([ bash-completion: ${enable_bash_completion} Bindings: + Lua: ${enable_lua} + Lua shared path: ${LUA_INSTALL_LMOD} + Lua module path: ${LUA_INSTALL_CMOD} Perl: ${enable_perl} Perl module path: ${PERL_MODPATH} Perl manual path: ${PERL_MANPATH} diff --git a/data/database.db b/data/database.db index 86e7e42..847fca5 100644 Binary files a/data/database.db and b/data/database.db differ diff --git a/debian/changelog b/debian/changelog index de26894..a31c4d5 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,208 @@ +libloc (0.9.18-1) unstable; urgency=medium + + [ Michael Tremer ] + * python: Make AS objects hashable + * network: Fix bit length check when merging networks + * network: Fix deduplication not looking far enough + * network-tree: Split into a separate file + * tools: Import the copy script + * network-tree: Run deduplication once for each family + * network-tree: Use the raw prefix to place networks onto the tree + * network: Tidy up some code + * tests: Add a simple test for deduplication + * python: Make Country hashable + * python: Fix rich comparison function for Country + * tests: Add some tests for the Country object + * country: Return NULL/None for unset attributes + * python: Fix rich comparison for AS objects + * python: Implement rich comparison for Network objects + * tests: Build out more dedup tests + * tests: Show even large diffs in dedup tests + * tests: Add more dedup/merge tests + * tests: Make the new tests executable + * importer: Fix Python syntax error + * importer: Refactor ARIN importer + * importer: Drop previously imported AS names, too + * importer: Tidy up SQL queries + * configure: Require Lua + * lua: Create scaffolding for a module + * .gitignore: Ignore some more 
temporary files + * lua: Add version() function + * lua: Initialize location context + * lua: Add a Database object with a dummy lookup function + * lua: Add a Network object + * lua: Actually return a network after lookup() + * lua: database: Rename the __gc method for consistency + * lua: database: Add __close method + * lua: network: Add a __close method + * lua: Fix calling methods that belong to an object + * lua: Check if we can read the network's properties + * lua: Force garbage collection to test those methods + * lua: Create Country objects + * lua: Force testing garbage collection for countries + * lua: Don't try to free memory that was allocated by Lua + * lua: country: Export remaining properties + * lua: Add AS object + * lua: database: Implementing fetching AS objects + * lua: database: Implement fetching countries + * lua: database: Export description/license/vendor + * lua: database: Implement verify() + * configure: Check for luaunit + * lua: Export default database path + * lua: Export network flags + * lua: network: Implement checking flags + * configure: Don't check for Lua when --disable-lua is passed + * importer: Pass downloader to the ARIN import + * importer: Create an extra table for feeds + * importer: Import the Spamhaus ASN feed into the new feeds table + * importer: Create a feeds table for networks, too + * importer: Don't import any AS names from Spamhaus + * importer: Import Spamhaus DROP feeds into the new feeds table + * importer: Import AWS IP feed into the new feeds table + * importer: Fix typo in SQL query + * importer: Reformat the large SQL query + * importer: Create a new subcommand to import feeds + * importer: Refactor feed parsing + * importer: Simplify fetching countries + * importer: Reformat AWS dictionary + * importer: Completely rewrite the AWS parser + * importer: Add the option to only update one feed + * importer: Refactor parsing Spamhaus DROP feeds + * importer: Refactor parsing Spamhaus ASNDROP + * importer: Drop source field from overrides table + * importer: Drop any data from feeds we no longer support + * importer: Use the downloader to import Geofeeds + * importer: Remove superfluous function call + * importer: Fail if no countries have been imported, yet + * importer: Ignore certain country codes + * importer: Make translating country codes more extensible + * importer: Return known countries as a set() + * importer: When storing country codes, make the code more straight- + forward + * importer: Skip any countries that we don't know + * importer: Change country code logic + * importer: Improve check for network objects + * importer: Improve checks for unspecified networks + * importer: Also import networks that are smaller than /48 or /24 + * importer: Add option to only import specific RIRs + * importer: Create a better structure to import RIRs + * importer: Drop the geofeed sources when updating RIR data + * importer: No longer import Geofeeds concurrently + * database: Migrate to psycopg3 + * importer: Move the split functions into the main importer + * importer: Merge the downloader into our main downloader + * Update translations + * Add README.md + * tree: Fix memory leak in dedup code + * tree: Be smarter when removing networks from the stack + * tree: Don't check if we need to fill the stack + * importer: Add structure to add Geofeed overrides + * importer: Check imported Geofeed override URLs + * importer: Ignore comments in Geofeeds + * database: Create a connection pool for async operation + * importer: Wrap 
everything into asyncio + * importer: Timeout if fetching a Geofeed takes longer than 5 seconds + * importer: Use database pipelining when parsing feeds + * importer: Fix incorrect variable name + * importer: Fix another incorrect variable name + * importer: Skip ASN lines in extended format + * importer: Fix another variable error in CSV parser + * importer: Convert the file handle to text before passing to the CSV + parser + * importer: Remove a debugging line + * importer: Import Geofeed overrides with other Geofeeds + * importer: Unify the way we check Geofeed URLs + * importer: Currently update the source when encountering a conflict + * importer: Allow storing multiple Geofeeds for the same network + * importer: Convert networks back to string + * importer: Remove more traces of the Geofeed overrides table + * importer: ANALYZE all tables before we are running the export + * importer: Replace all GIST indexes with SP-GIST + * importer: Make the export 200x faster + * importer: Drop any indexes we no longer need + * importer: Drop even more indexes + * importer: Permit Geofeeds for everything instead of ignoring + * address: Fix bit length calculation + * network: Fix handling bit length on merge + * tests: Add tests for #13236 + * lua: Add compatibility function to compile with Lua >= 5.1 + * lua: Don't raise an error if a network cannot be found + * tests: Fix bit length tests + * lua: Create a simple iterator for all networks + * lua: Cleanup any database iterators + * address: Add functions to access a specific byte/nibble in an + address + * network: Add function to return a reverse pointer for networks + * lua: Add function that returns subnets of a network + * lua: Add method to access database creation time + * importer: Drop EDROP as it has been merged into DROP + * configure: Use pkg-config to find OpenSSL + * writer: Fail if the header could not be written successfully + * writer: Move the cursor back to end when finished writing + * database: Re-open the file handle in r+ mode + * configure: Scan for multiple Lua versions + * lua: Initialize the database object pointer + * Revert "database: Re-open the file handle in r+ mode" + * libloc: Allow passing a pointer to the log callback + * lua: Implement setting a log callback function + * database: Have the lookup function return 0 even if nothing was + found + * lua: Fix raising an exception if no network was found + * importer: Ensure that we set timestamps in the announcements table + * lua: Check if we got returned something on some tests + * lua: Ensure that the testsuite is being executed with the correct + version + * tests: lua: Set a variable to true if we are checking for a boolean + later + * libloc: Refactor summarizing IP address ranges + * perl: Fix a couple of NULL-pointer derefences in the module + * importer: Update a few AWS locations + * importer: Ignore any sudden disconnects when we fetch a Geofeed + * importer: Don't import /4 or /10 networks from the routing table + * po: Update POTFILES.in + * data: Import today's database + * jenkins: Initial import + * configure: Don't automatically detect systemdunitdir + * configure: Check syntax of Lua check + * jenkins: Always Lua extension, too + * jenkins: Don't expect any tests to fail + * jenkins: Remove extra tests we currently don't support + * configure: Fail if Lua was enabled, but not found + * jenkins: Add all supported Debian architectures + * jenkins: Actually install Lua when we want to build against it + * configure: Don't run Lua tests if luaunit is 
not available + * jenkins: Install lua-unit wherever available + * configure: Make Lua check work on Fedora, too + * tree: Add network to the stack after we have tried again + * tree: Try harder to merge networks + * jenkins: Upload logs from tests in src/, too + * tests: Make bit length mismatch message clearer + * Fix all sorts of string formatting issues + * configure: Enable more compiler warnings + * address: Never pass zero to __builtin_ctz + * tests: Constify path to open in database test + * jenkins: Install perl(Test::More) on Fedora + * log: Perform formatting string sanitation when logging to stderr + * tree: Replace bitfields with flags to mark deleted nodes + * lua: Perform formatting string sanitization + * jenkins: Fix syntax to install perl(Test::More) + * python: Fix type for keyword lists + * python: Fix unintended fallthrough + * Fix string formatting issues on 32 bit systems + * network: Remove dead assignment + * database: Correctly check return value of dup() + + [ Peter Müller ] + * location-importer: Add missing area code for AWS + * location-importer: Fix Spamhaus ASN-DROP parsing + * location-importer: Replace ARIN AS names source with one that offers + human-readable names + + [ Stefan Schantl ] + * perl: Return nothing in case invalid data has been passed to libloc + + -- Michael Tremer Mon, 10 Mar 2025 11:04:29 +0000 + libloc (0.9.17-1) unstable; urgency=medium [ Michael Tremer ] diff --git a/debian/control b/debian/control index 918c0f6..85f5e45 100644 --- a/debian/control +++ b/debian/control @@ -6,8 +6,8 @@ Standards-Version: 4.6.1 Build-Depends: debhelper-compat (= 13), dh-sequence-python3, - asciidoc, intltool, + liblua5.4-dev, libssl-dev, libsystemd-dev, pkg-config, @@ -98,3 +98,15 @@ Priority: optional Section: oldlibs Description: transitional package This is a transitional package. It can safely be removed. + +Package: lua-location +Architecture: any +Section: libs +Depends: + ${misc:Depends}, + ${shlibs:Depends}, +Multi-Arch: foreign +Description: ${source:Synopsis} (Lua bindings) + ${source:Extended-Description} + . + This package provides the Lua bindings for libloc. 
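
The changelog above introduces a logging callback ("libloc: Allow passing a pointer
to the log callback"); the corresponding declarations appear in the src/libloc/libloc.h
hunk further down. Below is a minimal sketch of how a consumer registers it. This is
not part of the upstream patch, and the "example" prefix string is just an arbitrary
opaque data pointer.

	#include <stdarg.h>
	#include <stdio.h>
	#include <syslog.h>

	#include <libloc/libloc.h>

	// Matches the loc_log_callback typedef from src/libloc/libloc.h
	static void log_callback(struct loc_ctx* ctx, void* data, int priority,
			const char* file, int line, const char* fn, const char* format, va_list args) {
		const char* prefix = data;
		(void)ctx;
		(void)priority;

		fprintf(stderr, "%s: %s:%d %s(): ", prefix, file, line, fn);
		vfprintf(stderr, format, args);
	}

	int main(void) {
		struct loc_ctx* ctx = NULL;

		if (loc_new(&ctx))
			return 1;

		// Register the callback together with an opaque data pointer
		loc_set_log_callback(ctx, log_callback, "example");

		// Raise the priority so that debug messages reach the callback, too
		loc_set_log_priority(ctx, LOG_DEBUG);

		// ... use the library; messages are now routed through log_callback() ...

		loc_unref(ctx);
		return 0;
	}
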
diff --git a/debian/libloc-dev.install b/debian/libloc-dev.install
index 04e85fa..d93d217 100644
--- a/debian/libloc-dev.install
+++ b/debian/libloc-dev.install
@@ -1,4 +1,3 @@
 usr/include/libloc
 usr/lib/*/libloc.so
 usr/lib/*/pkgconfig
-usr/share/man/man3
diff --git a/debian/location.install b/debian/location.install
index deb6c4d..f2e3fe3 100644
--- a/debian/location.install
+++ b/debian/location.install
@@ -2,5 +2,4 @@ usr/bin
 usr/share/bash-completion/completions/location
 var/lib/location/database.db
 var/lib/location/signing-key.pem
-lib/systemd/system
-usr/share/man/man1
+usr/lib/systemd/system
diff --git a/debian/lua-location.install b/debian/lua-location.install
new file mode 100644
index 0000000..4598956
--- /dev/null
+++ b/debian/lua-location.install
@@ -0,0 +1 @@
+usr/lib/*/lua/*
diff --git a/debian/rules b/debian/rules
index e5e3f18..dea9931 100755
--- a/debian/rules
+++ b/debian/rules
@@ -5,7 +5,7 @@ export PYBUILD_SYSTEM=custom
 export PYBUILD_CLEAN_ARGS=dh_auto_clean
 export PYBUILD_CONFIGURE_ARGS=intltoolize --force --automake; \
 	PYTHON={interpreter} dh_auto_configure -- \
-		--disable-perl
+		--disable-man-pages --disable-perl
 export PYBUILD_BUILD_ARGS=dh_auto_build
 export PYBUILD_INSTALL_ARGS=dh_auto_install --destdir={destdir}; \
 	mkdir -p {destdir}/usr/lib/python{version}/dist-packages; \
diff --git a/m4/ax_prog_lua_modules.m4 b/m4/ax_prog_lua_modules.m4
new file mode 100644
index 0000000..8af66fe
--- /dev/null
+++ b/m4/ax_prog_lua_modules.m4
@@ -0,0 +1,67 @@
+#
+# SYNOPSIS
+#
+#   AX_PROG_LUA_MODULES([MODULES], [ACTION-IF-TRUE], [ACTION-IF-FALSE])
+#
+# DESCRIPTION
+#
+#   Checks to see if the given Lua modules are available. If true the shell
+#   commands in ACTION-IF-TRUE are executed. If not the shell commands in
+#   ACTION-IF-FALSE are run. Note that if $LUA is not set (for example by
+#   calling AC_CHECK_PROG or AC_PATH_PROG), AC_CHECK_PROG(LUA, lua, lua)
+#   will be run.
+#
+#   MODULES is a space-separated list of module names. To check for a
+#   minimum version of a module, append the version number to the module
+#   name, separated by an equals sign.
+#
+#   Example:
+#
+#     AX_PROG_LUA_MODULES(module=1.0.3,, AC_MSG_WARN(Need some Lua modules))
+#
+# LICENSE
+#
+#   Copyright (c) 2024 Michael Tremer
+#
+#   Copying and distribution of this file, with or without modification, are
+#   permitted in any medium without royalty provided the copyright notice
+#   and this notice are preserved. This file is offered as-is, without any
+#   warranty.
+
+AU_ALIAS([AC_PROG_LUA_MODULES], [AX_PROG_LUA_MODULES])
+AC_DEFUN([AX_PROG_LUA_MODULES], [dnl
+	m4_define([ax_lua_modules])
+	m4_foreach([ax_lua_module], m4_split(m4_normalize([$1])), [
+		m4_append([ax_lua_modules], [']m4_bpatsubst(ax_lua_module,=,[ ])[' ])
+	])
+
+	# Make sure we have Lua
+	if test -z "$LUA"; then
+		AC_CHECK_PROG(LUA, lua, lua)
+	fi
+
+	if test "x$LUA" != x; then
+		ax_lua_modules_failed=0
+		for ax_lua_module in ax_lua_modules; do
+			AC_MSG_CHECKING(for Lua module $ax_lua_module)
+
+			# Would be nice to log result here, but can't rely on autoconf internals
+			$LUA -e "require('$ax_lua_module')" > /dev/null 2>&1
+			if test $? 
-ne 0; then + AC_MSG_RESULT(no); + ax_lua_modules_failed=1 + else + AC_MSG_RESULT(ok); + fi + done + + # Run optional shell commands + if test "$ax_lua_modules_failed" = 0; then + :; $2 + else + :; $3 + fi + else + AC_MSG_WARN(could not find Lua) + fi +])dnl diff --git a/po/POTFILES.in b/po/POTFILES.in index da20831..520a8c1 100644 --- a/po/POTFILES.in +++ b/po/POTFILES.in @@ -5,9 +5,9 @@ src/python/location/database.py src/python/location/downloader.py src/python/location/export.py src/python/location/i18n.py -src/python/location/importer.py src/python/location/logger.py src/scripts/location-importer.in src/scripts/location.in src/systemd/location-update.service.in src/systemd/location-update.timer.in +tools/copy.py diff --git a/po/de.po b/po/de.po index 21b531d..f5e8944 100644 --- a/po/de.po +++ b/po/de.po @@ -7,7 +7,7 @@ msgid "" msgstr "" "Project-Id-Version: libloc 0\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2022-10-29 12:46+0000\n" +"POT-Creation-Date: 2024-03-04 12:21+0000\n" "PO-Revision-Date: 2018-02-01 14:05+0000\n" "Last-Translator: Michael Tremer \n" "Language-Team: German\n" @@ -82,6 +82,9 @@ msgstr "" msgid "Update WHOIS Information" msgstr "" +msgid "Only update these sources" +msgstr "" + msgid "Update BGP Annoucements" msgstr "" @@ -91,6 +94,15 @@ msgstr "" msgid "SERVER" msgstr "" +msgid "Update Geofeeds" +msgstr "" + +msgid "Update Feeds" +msgstr "" + +msgid "Only update these feeds" +msgstr "" + msgid "Update overrides" msgstr "" diff --git a/po/ka.po b/po/ka.po index cba950e..25e60b8 100644 --- a/po/ka.po +++ b/po/ka.po @@ -7,7 +7,7 @@ msgid "" msgstr "" "Project-Id-Version: libloc\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2022-10-29 12:46+0000\n" +"POT-Creation-Date: 2024-03-04 12:21+0000\n" "PO-Revision-Date: 2023-02-22 08:57+0100\n" "Last-Translator: Temuri Doghonadze \n" "Language-Team: Georgian <(nothing)>\n" @@ -83,6 +83,9 @@ msgstr "ბაზის ფორმატის ვერსია" msgid "Update WHOIS Information" msgstr "WHOIS-ის ინფორმაციის განახლება" +msgid "Only update these sources" +msgstr "" + msgid "Update BGP Annoucements" msgstr "BGP-ის ანონსების განახლება" @@ -92,6 +95,17 @@ msgstr "რომელ რაუტის სერვერს დავუკ msgid "SERVER" msgstr "სერვერი" +#, fuzzy +msgid "Update Geofeeds" +msgstr "განახლება გადაფარავს" + +#, fuzzy +msgid "Update Feeds" +msgstr "განახლება გადაფარავს" + +msgid "Only update these feeds" +msgstr "" + msgid "Update overrides" msgstr "განახლება გადაფარავს" diff --git a/src/country.c b/src/country.c index 309cee1..8152a89 100644 --- a/src/country.c +++ b/src/country.c @@ -99,6 +99,9 @@ LOC_EXPORT const char* loc_country_get_code(struct loc_country* country) { } LOC_EXPORT const char* loc_country_get_continent_code(struct loc_country* country) { + if (!*country->continent_code) + return NULL; + return country->continent_code; } diff --git a/src/database.c b/src/database.c index 617b61e..0c86085 100644 --- a/src/database.c +++ b/src/database.c @@ -79,7 +79,7 @@ struct loc_database { // Data mapped into memory char* data; - off_t length; + ssize_t length; struct loc_stringpool* pool; @@ -149,17 +149,17 @@ struct loc_database_enumerator { static inline int __loc_database_check_boundaries(struct loc_database* db, const char* p, const size_t length) { - size_t offset = p - db->data; + ssize_t offset = p - db->data; // Return if everything is within the boundary - if (offset <= db->length - length) + if (offset <= (ssize_t)(db->length - length)) return 1; DEBUG(db->ctx, "Database read check failed at %p for %zu byte(s)\n", p, length); - DEBUG(db->ctx, " 
p = %p (offset = %jd, length = %zu)\n", p, offset, length); - DEBUG(db->ctx, " data = %p (length = %zu)\n", db->data, db->length); + DEBUG(db->ctx, " p = %p (offset = %zd, length = %zu)\n", p, offset, length); + DEBUG(db->ctx, " data = %p (length = %zd)\n", db->data, db->length); DEBUG(db->ctx, " end = %p\n", db->data + db->length); - DEBUG(db->ctx, " overflow of %zu byte(s)\n", offset + length - db->length); + DEBUG(db->ctx, " overflow of %zd byte(s)\n", (ssize_t)(offset + length - db->length)); // Otherwise raise EFAULT errno = EFAULT; @@ -258,7 +258,7 @@ static int loc_database_mmap(struct loc_database* db) { return 1; } - DEBUG(db->ctx, "Mapped database of %zu byte(s) at %p\n", db->length, db->data); + DEBUG(db->ctx, "Mapped database of %zd byte(s) at %p\n", db->length, db->data); // Tell the system that we expect to read data randomly r = madvise(db->data, db->length, MADV_RANDOM); @@ -403,7 +403,7 @@ static int loc_database_clone_handle(struct loc_database* db, FILE* f) { // Clone file descriptor fd = dup(fd); - if (!fd) { + if (fd < 0) { ERROR(db->ctx, "Could not duplicate file descriptor\n"); return 1; } @@ -618,7 +618,7 @@ LOC_EXPORT int loc_database_verify(struct loc_database* db, FILE* f) { break; default: - ERROR(db->ctx, "Cannot compute hash for database with format %d\n", + ERROR(db->ctx, "Cannot compute hash for database with format %u\n", db->version); r = -EINVAL; goto CLEANUP; @@ -924,11 +924,12 @@ static int __loc_database_lookup(struct loc_database* db, const struct in6_addr* // If this node has a leaf, we will check if it matches if (__loc_database_node_is_leaf(node_v1)) { r = __loc_database_lookup_handle_leaf(db, address, network, network_address, level, node_v1); - if (r <= 0) + if (r < 0) return r; } - return 1; + // Return no error - even if nothing was found + return 0; } LOC_EXPORT int loc_database_lookup(struct loc_database* db, @@ -1235,7 +1236,7 @@ LOC_EXPORT int loc_database_enumerator_next_as( r = loc_as_match_string(*as, enumerator->string); if (r == 1) { - DEBUG(enumerator->ctx, "AS%d (%s) matches %s\n", + DEBUG(enumerator->ctx, "AS%u (%s) matches %s\n", loc_as_get_number(*as), loc_as_get_name(*as), enumerator->string); return 0; @@ -1346,12 +1347,12 @@ static int __loc_database_enumerator_next_network( *network = NULL; } - DEBUG(enumerator->ctx, "Called with a stack of %u nodes\n", + DEBUG(enumerator->ctx, "Called with a stack of %d nodes\n", enumerator->network_stack_depth); // Perform DFS while (enumerator->network_stack_depth > 0) { - DEBUG(enumerator->ctx, "Stack depth: %u\n", enumerator->network_stack_depth); + DEBUG(enumerator->ctx, "Stack depth: %d\n", enumerator->network_stack_depth); // Get object from top of the stack struct loc_node_stack* node = &enumerator->network_stack[enumerator->network_stack_depth]; @@ -1625,7 +1626,6 @@ static int __loc_database_enumerator_next_bogon( return 0; FINISH: - if (!loc_address_all_zeroes(&enumerator->gap6_start)) { r = loc_address_reset_last(&gap_end, AF_INET6); if (r) diff --git a/src/libloc.c b/src/libloc.c index 450c5e6..c2deed7 100644 --- a/src/libloc.c +++ b/src/libloc.c @@ -30,10 +30,15 @@ struct loc_ctx { int refcount; - void (*log_fn)(struct loc_ctx* ctx, - int priority, const char *file, int line, const char *fn, - const char *format, va_list args); - int log_priority; + + // Logging + struct loc_ctx_logging { + int priority; + + // Callback + loc_log_callback callback; + void* data; + } log; }; void loc_log(struct loc_ctx* ctx, @@ -42,11 +47,15 @@ void loc_log(struct loc_ctx* ctx, va_list args; 
va_start(args, format); - ctx->log_fn(ctx, priority, file, line, fn, format, args); + ctx->log.callback(ctx, ctx->log.data, priority, file, line, fn, format, args); va_end(args); } -static void log_stderr(struct loc_ctx* ctx, +static void log_stderr(struct loc_ctx* ctx, void* data, int priority, + const char* file, int line, const char* fn, const char* format, va_list args) + __attribute__((format(printf, 7, 0))); + +static void log_stderr(struct loc_ctx* ctx, void* data, int priority, const char* file, int line, const char* fn, const char* format, va_list args) { fprintf(stderr, "libloc: %s: ", fn); @@ -79,15 +88,15 @@ LOC_EXPORT int loc_new(struct loc_ctx** ctx) { return 1; c->refcount = 1; - c->log_fn = log_stderr; - c->log_priority = LOG_ERR; + c->log.callback = log_stderr; + c->log.priority = LOG_ERR; const char* env = secure_getenv("LOC_LOG"); if (env) loc_set_log_priority(c, log_priority(env)); INFO(c, "ctx %p created\n", c); - DEBUG(c, "log_priority=%d\n", c->log_priority); + DEBUG(c, "log_priority=%d\n", c->log.priority); *ctx = c; return 0; @@ -112,17 +121,22 @@ LOC_EXPORT struct loc_ctx* loc_unref(struct loc_ctx* ctx) { return NULL; } +LOC_EXPORT void loc_set_log_callback(struct loc_ctx* ctx, loc_log_callback callback, void* data) { + ctx->log.callback = callback; + ctx->log.data = data; +} + LOC_EXPORT void loc_set_log_fn(struct loc_ctx* ctx, void (*log_fn)(struct loc_ctx* ctx, int priority, const char* file, int line, const char* fn, const char* format, va_list args)) { - ctx->log_fn = log_fn; + //ctx->log_fn = log_fn; INFO(ctx, "custom logging function %p registered\n", log_fn); } LOC_EXPORT int loc_get_log_priority(struct loc_ctx* ctx) { - return ctx->log_priority; + return ctx->log.priority; } LOC_EXPORT void loc_set_log_priority(struct loc_ctx* ctx, int priority) { - ctx->log_priority = priority; + ctx->log.priority = priority; } diff --git a/src/libloc.sym b/src/libloc.sym index 29e17f0..b4bce8d 100644 --- a/src/libloc.sym +++ b/src/libloc.sym @@ -1,6 +1,7 @@ LIBLOC_1 { global: loc_ref; + loc_set_log_callback; loc_get_log_priority; loc_set_log_fn; loc_unref; @@ -146,3 +147,10 @@ global: local: *; }; + +LIBLOC_2 { +global: + loc_network_reverse_pointer; +local: + *; +} LIBLOC_1; diff --git a/src/libloc/address.h b/src/libloc/address.h index f4c0ee3..ff6e943 100644 --- a/src/libloc/address.h +++ b/src/libloc/address.h @@ -131,15 +131,69 @@ static inline struct in6_addr loc_prefix_to_bitmask(const unsigned int prefix) { } static inline unsigned int loc_address_bit_length(const struct in6_addr* address) { + unsigned int bitlength = 0; + int trailing_zeroes; + int octet = 0; - foreach_octet_in_address(octet, address) { - if (address->s6_addr[octet]) - return (15 - octet) * 8 + 32 - __builtin_clz(address->s6_addr[octet]); + + // Initialize the bit length + if (IN6_IS_ADDR_V4MAPPED(address)) + bitlength = 32; + else + bitlength = 128; + + // Walk backwards until we find the first one + foreach_octet_in_address_reverse(octet, address) { + // __builtin_ctz does not support zero as input + if (!address->s6_addr[octet]) { + bitlength -= 8; + continue; + } + + // Count all trailing zeroes + trailing_zeroes = __builtin_ctz(address->s6_addr[octet]); + + // We only have one byte + if (trailing_zeroes > 8) + trailing_zeroes = 8; + + // Remove any trailing zeroes from the total length + bitlength -= trailing_zeroes; + + if (trailing_zeroes < 8) + return bitlength; } return 0; } +static inline int loc_address_common_bits(const struct in6_addr* a1, const struct in6_addr* a2) { + int bits = 
0; + + // Both must be of the same family + if (IN6_IS_ADDR_V4MAPPED(a1) && !IN6_IS_ADDR_V4MAPPED(a2)) + return -EINVAL; + + else if (!IN6_IS_ADDR_V4MAPPED(a1) && IN6_IS_ADDR_V4MAPPED(a2)) + return -EINVAL; + + // Walk through both addresses octet by octet + for (unsigned int i = (IN6_IS_ADDR_V4MAPPED(a1) ? 12 : 0); i <= 15; i++) { + // Fast path if the entire octet matches + if (a1->s6_addr[i] == a2->s6_addr[i]) { + bits += 8; + + // Otherwise we XOR the octets and count the leading zeroes + // (where both octets have been identical). + } else { + bits += __builtin_clz(a1->s6_addr[i] ^ a2->s6_addr[i]) - 24; + break; + } + } + + return bits; +} + static inline int loc_address_reset(struct in6_addr* address, int family) { switch (family) { case AF_INET6: @@ -265,19 +319,35 @@ static inline void loc_address_decrement(struct in6_addr* address) { } } -static inline int loc_address_count_trailing_zero_bits(const struct in6_addr* address) { - int zeroes = 0; +static inline int loc_address_get_octet(const struct in6_addr* address, const unsigned int i) { + if (IN6_IS_ADDR_V4MAPPED(address)) { + if (i >= 4) + return -ERANGE; - int octet = 0; - foreach_octet_in_address_reverse(octet, address) { - if (address->s6_addr[octet]) { - zeroes += __builtin_ctz(address->s6_addr[octet]); - break; - } else - zeroes += 8; + return (address->s6_addr32[3] >> (i * 8)) & 0xff; + + } else { + if (i >= 32) + return -ERANGE; + + return address->s6_addr[i]; } +} + +static inline int loc_address_get_nibble(const struct in6_addr* address, const unsigned int i) { + int octet = 0; + + // Fetch the octet + octet = loc_address_get_octet(address, i / 2); + if (octet < 0) + return octet; + + // Shift if we want an uneven nibble + if (i % 2 == 0) + octet >>= 4; - return zeroes; + // Return the nibble + return octet & 0x0f; } #endif /* LIBLOC_PRIVATE */ diff --git a/src/libloc/libloc.h b/src/libloc/libloc.h index 938ed75..ea943a9 100644 --- a/src/libloc/libloc.h +++ b/src/libloc/libloc.h @@ -29,6 +29,18 @@ struct loc_ctx *loc_ref(struct loc_ctx* ctx); struct loc_ctx *loc_unref(struct loc_ctx* ctx); int loc_new(struct loc_ctx** ctx); + +typedef void (*loc_log_callback)( + struct loc_ctx* ctx, + void* data, + int priority, + const char* file, + int line, + const char* fn, + const char* format, + va_list args); +void loc_set_log_callback(struct loc_ctx* ctx, loc_log_callback callback, void* data); + void loc_set_log_fn(struct loc_ctx* ctx, void (*log_fn)(struct loc_ctx* ctx, int priority, const char* file, int line, const char* fn, diff --git a/src/libloc/network-tree.h b/src/libloc/network-tree.h new file mode 100644 index 0000000..13052b7 --- /dev/null +++ b/src/libloc/network-tree.h @@ -0,0 +1,65 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2017-2024 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+*/ + +#ifndef LIBLOC_NETWORK_TREE_H +#define LIBLOC_NETWORK_TREE_H + +#ifdef LIBLOC_PRIVATE + +#include +#include + +struct loc_network_tree; + +int loc_network_tree_new(struct loc_ctx* ctx, struct loc_network_tree** tree); + +struct loc_network_tree* loc_network_tree_unref(struct loc_network_tree* tree); + +struct loc_network_tree_node* loc_network_tree_get_root(struct loc_network_tree* tree); + +int loc_network_tree_walk(struct loc_network_tree* tree, + int(*filter_callback)(struct loc_network* network, void* data), + int(*callback)(struct loc_network* network, void* data), void* data); + +int loc_network_tree_dump(struct loc_network_tree* tree); + +int loc_network_tree_add_network(struct loc_network_tree* tree, struct loc_network* network); + +size_t loc_network_tree_count_nodes(struct loc_network_tree* tree); + +int loc_network_tree_cleanup(struct loc_network_tree* tree); + +/* + Nodes +*/ + +struct loc_network_tree_node; + +int loc_network_tree_node_new(struct loc_ctx* ctx, struct loc_network_tree_node** node); + +struct loc_network_tree_node* loc_network_tree_node_ref(struct loc_network_tree_node* node); +struct loc_network_tree_node* loc_network_tree_node_unref(struct loc_network_tree_node* node); + +struct loc_network_tree_node* loc_network_tree_node_get( + struct loc_network_tree_node* node, unsigned int index); + +int loc_network_tree_node_is_leaf(struct loc_network_tree_node* node); + +struct loc_network* loc_network_tree_node_get_network(struct loc_network_tree_node* node); + +#endif /* LIBLOC_PRIVATE */ + +#endif /* LIBLOC_NETWORK_TREE_H */ diff --git a/src/libloc/network.h b/src/libloc/network.h index ccbcaa2..6f2dad2 100644 --- a/src/libloc/network.h +++ b/src/libloc/network.h @@ -66,33 +66,18 @@ struct loc_network_list* loc_network_exclude( struct loc_network_list* loc_network_exclude_list( struct loc_network* network, struct loc_network_list* list); +char* loc_network_reverse_pointer(struct loc_network* network, const char* suffix); + #ifdef LIBLOC_PRIVATE +int loc_network_properties_cmp(struct loc_network* self, struct loc_network* other); +unsigned int loc_network_raw_prefix(struct loc_network* network); + int loc_network_to_database_v1(struct loc_network* network, struct loc_database_network_v1* dbobj); int loc_network_new_from_database_v1(struct loc_ctx* ctx, struct loc_network** network, struct in6_addr* address, unsigned int prefix, const struct loc_database_network_v1* dbobj); -struct loc_network_tree; -int loc_network_tree_new(struct loc_ctx* ctx, struct loc_network_tree** tree); -struct loc_network_tree* loc_network_tree_unref(struct loc_network_tree* tree); -struct loc_network_tree_node* loc_network_tree_get_root(struct loc_network_tree* tree); -int loc_network_tree_walk(struct loc_network_tree* tree, - int(*filter_callback)(struct loc_network* network, void* data), - int(*callback)(struct loc_network* network, void* data), void* data); -int loc_network_tree_dump(struct loc_network_tree* tree); -int loc_network_tree_add_network(struct loc_network_tree* tree, struct loc_network* network); -size_t loc_network_tree_count_nodes(struct loc_network_tree* tree); - -struct loc_network_tree_node; -int loc_network_tree_node_new(struct loc_ctx* ctx, struct loc_network_tree_node** node); -struct loc_network_tree_node* loc_network_tree_node_ref(struct loc_network_tree_node* node); -struct loc_network_tree_node* loc_network_tree_node_unref(struct loc_network_tree_node* node); -struct loc_network_tree_node* loc_network_tree_node_get(struct loc_network_tree_node* node, unsigned 
int index); - -int loc_network_tree_node_is_leaf(struct loc_network_tree_node* node); -struct loc_network* loc_network_tree_node_get_network(struct loc_network_tree_node* node); - -int loc_network_tree_cleanup(struct loc_network_tree* tree); +int loc_network_merge(struct loc_network** n, struct loc_network* n1, struct loc_network* n2); #endif #endif diff --git a/src/lua/as.c b/src/lua/as.c new file mode 100644 index 0000000..558fcbf --- /dev/null +++ b/src/lua/as.c @@ -0,0 +1,136 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2024 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. +*/ + +#include +#include +#include + +#include +#include + +#include + +#include "location.h" +#include "as.h" +#include "compat.h" + +typedef struct as { + struct loc_as* as; +} AS; + +static AS* luaL_checkas(lua_State* L, int i) { + void* userdata = luaL_checkudata(L, i, "location.AS"); + + // Throw an error if the argument doesn't match + luaL_argcheck(L, userdata, i, "AS expected"); + + return (AS*)userdata; +} + +int create_as(lua_State* L, struct loc_as* as) { + // Allocate a new object + AS* self = (AS*)lua_newuserdata(L, sizeof(*self)); + + // Set metatable + luaL_setmetatable(L, "location.AS"); + + // Store country + self->as = loc_as_ref(as); + + return 1; +} + +static int AS_new(lua_State* L) { + struct loc_as* as = NULL; + unsigned int n = 0; + int r; + + // Fetch the number + n = luaL_checknumber(L, 1); + + // Create the AS + r = loc_as_new(ctx, &as, n); + if (r) + return luaL_error(L, "Could not create AS %u: %s\n", n, strerror(errno)); + + // Return the AS + r = create_as(L, as); + loc_as_unref(as); + + return r; +} + +static int AS_gc(lua_State* L) { + AS* self = luaL_checkas(L, 1); + + // Free AS + if (self->as) { + loc_as_unref(self->as); + self->as = NULL; + } + + return 0; +} + +static int AS_tostring(lua_State* L) { + AS* self = luaL_checkas(L, 1); + + uint32_t number = loc_as_get_number(self->as); + const char* name = loc_as_get_name(self->as); + + // Return string + if (name) + lua_pushfstring(L, "AS%d - %s", number, name); + else + lua_pushfstring(L, "AS%d", number); + + return 1; +} + +// Name + +static int AS_get_name(lua_State* L) { + AS* self = luaL_checkas(L, 1); + + // Return the name + lua_pushstring(L, loc_as_get_name(self->as)); + + return 1; +} + +// Number + +static int AS_get_number(lua_State* L) { + AS* self = luaL_checkas(L, 1); + + // Return the number + lua_pushnumber(L, loc_as_get_number(self->as)); + + return 1; +} + +static const struct luaL_Reg AS_functions[] = { + { "new", AS_new }, + { "get_name", AS_get_name }, + { "get_number", AS_get_number }, + { "__gc", AS_gc }, + { "__tostring", AS_tostring }, + { NULL, NULL }, +}; + +int register_as(lua_State* L) { + return register_class(L, "location.AS", AS_functions); +} diff --git a/src/lua/as.h b/src/lua/as.h new file mode 100644 index 0000000..0ea34f9 --- /dev/null +++ b/src/lua/as.h @@ -0,0 +1,29 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 
2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#ifndef LUA_LOCATION_AS_H
+#define LUA_LOCATION_AS_H
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#include <libloc/as.h>
+
+int register_as(lua_State* L);
+
+int create_as(lua_State* L, struct loc_as* as);
+
+#endif /* LUA_LOCATION_AS_H */
diff --git a/src/lua/compat.h b/src/lua/compat.h
new file mode 100644
index 0000000..f0172b8
--- /dev/null
+++ b/src/lua/compat.h
@@ -0,0 +1,56 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#ifndef LUA_LOCATION_COMPAT_H
+#define LUA_LOCATION_COMPAT_H
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#if LUA_VERSION_RELEASE_NUM < 502
+
+static inline void luaL_setmetatable(lua_State* L, const char* name) {
+	luaL_checkstack(L, 1, "not enough stack slots");
+	luaL_getmetatable(L, name);
+	lua_setmetatable(L, -2);
+}
+
+static inline void luaL_setfuncs(lua_State* L, const luaL_Reg* l, int nup) {
+	int i;
+
+	luaL_checkstack(L, nup+1, "too many upvalues");
+
+	for (; l->name != NULL; l++) {
+		lua_pushstring(L, l->name);
+
+		for (i = 0; i < nup; i++)
+			lua_pushvalue(L, -(nup + 1));
+
+		lua_pushcclosure(L, l->func, nup);
+		lua_settable(L, -(nup + 3));
+	}
+
+	lua_pop(L, nup);
+}
+
+static inline void luaL_newlib(lua_State* L, const luaL_Reg* l) {
+	lua_newtable(L);
+	luaL_setfuncs(L, l, 0);
+}
+
+#endif /* Lua < 5.2 */
+
+#endif /* LUA_LOCATION_COMPAT_H */
diff --git a/src/lua/country.c b/src/lua/country.c
new file mode 100644
index 0000000..816bd2f
--- /dev/null
+++ b/src/lua/country.c
@@ -0,0 +1,144 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#include <libloc/country.h>
+
+#include "location.h"
+#include "compat.h"
+#include "country.h"
+
+typedef struct country {
+	struct loc_country* country;
+} Country;
+
+static Country* luaL_checkcountry(lua_State* L, int i) {
+	void* userdata = luaL_checkudata(L, i, "location.Country");
+
+	// Throw an error if the argument doesn't match
+	luaL_argcheck(L, userdata, i, "Country expected");
+
+	return (Country*)userdata;
+}
+
+int create_country(lua_State* L, struct loc_country* country) {
+	// Allocate a new object
+	Country* self = (Country*)lua_newuserdata(L, sizeof(*self));
+
+	// Set metatable
+	luaL_setmetatable(L, "location.Country");
+
+	// Store country
+	self->country = loc_country_ref(country);
+
+	return 1;
+}
+
+static int Country_new(lua_State* L) {
+	struct loc_country* country = NULL;
+	const char* code = NULL;
+	int r;
+
+	// Fetch the code
+	code = luaL_checkstring(L, 1);
+
+	// Parse the string
+	r = loc_country_new(ctx, &country, code);
+	if (r)
+		return luaL_error(L, "Could not create country %s: %s\n", code, strerror(errno));
+
+	// Return the country
+	r = create_country(L, country);
+	loc_country_unref(country);
+
+	return r;
+}
+
+static int Country_gc(lua_State* L) {
+	Country* self = luaL_checkcountry(L, 1);
+
+	// Free country
+	if (self->country) {
+		loc_country_unref(self->country);
+		self->country = NULL;
+	}
+
+	return 0;
+}
+
+static int Country_eq(lua_State* L) {
+	Country* self = luaL_checkcountry(L, 1);
+	Country* other = luaL_checkcountry(L, 2);
+
+	// Push comparison result
+	lua_pushboolean(L, loc_country_cmp(self->country, other->country) == 0);
+
+	return 1;
+}
+
+// Name
+
+static int Country_get_name(lua_State* L) {
+	Country* self = luaL_checkcountry(L, 1);
+
+	// Return the name
+	lua_pushstring(L, loc_country_get_name(self->country));
+
+	return 1;
+}
+
+// Code
+
+static int Country_get_code(lua_State* L) {
+	Country* self = luaL_checkcountry(L, 1);
+
+	// Return the code
+	lua_pushstring(L, loc_country_get_code(self->country));
+
+	return 1;
+}
+
+// Continent Code
+
+static int Country_get_continent_code(lua_State* L) {
+	Country* self = luaL_checkcountry(L, 1);
+
+	// Return the continent code
+	lua_pushstring(L, loc_country_get_continent_code(self->country));
+
+	return 1;
+}
+
+static const struct luaL_Reg Country_functions[] = {
+	{ "new", Country_new },
+	{ "get_code", Country_get_code },
+	{ "get_continent_code", Country_get_continent_code },
+	{ "get_name", Country_get_name },
+	{ "__eq", Country_eq },
+	{ "__gc", Country_gc },
+	{ "__tostring", Country_get_code },
+	{ NULL, NULL },
+};
+
+int register_country(lua_State* L) {
+	return register_class(L, "location.Country", Country_functions);
+}
diff --git a/src/lua/country.h b/src/lua/country.h
new file mode 100644
index 0000000..4997d9d
--- /dev/null
+++ b/src/lua/country.h
@@ -0,0 +1,29 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#ifndef LUA_LOCATION_COUNTRY_H
+#define LUA_LOCATION_COUNTRY_H
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#include <libloc/country.h>
+
+int register_country(lua_State* L);
+
+int create_country(lua_State* L, struct loc_country* country);
+
+#endif /* LUA_LOCATION_COUNTRY_H */
diff --git a/src/lua/database.c b/src/lua/database.c
new file mode 100644
index 0000000..dfdb705
--- /dev/null
+++ b/src/lua/database.c
@@ -0,0 +1,327 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#include <errno.h>
+#include <string.h>
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#include <libloc/database.h>
+
+#include "location.h"
+#include "as.h"
+#include "compat.h"
+#include "country.h"
+#include "database.h"
+#include "network.h"
+
+typedef struct database {
+	struct loc_database* db;
+} Database;
+
+static Database* luaL_checkdatabase(lua_State* L, int i) {
+	void* userdata = luaL_checkudata(L, i, "location.Database");
+
+	// Throw an error if the argument doesn't match
+	luaL_argcheck(L, userdata, i, "Database expected");
+
+	return (Database*)userdata;
+}
+
+static int Database_open(lua_State* L) {
+	const char* path = NULL;
+	FILE* f = NULL;
+	int r;
+
+	// Fetch the path
+	path = luaL_checkstring(L, 1);
+
+	// Allocate a new object
+	Database* self = (Database*)lua_newuserdata(L, sizeof(*self));
+	self->db = NULL;
+
+	// Set metatable
+	luaL_setmetatable(L, "location.Database");
+
+	// Open the database file
+	f = fopen(path, "r");
+	if (!f)
+		return luaL_error(L, "Could not open %s: %s\n", path, strerror(errno));
+
+	// Open the database
+	r = loc_database_new(ctx, &self->db, f);
+
+	// Close the file descriptor
+	fclose(f);
+
+	// Check for errors
+	if (r)
+		return luaL_error(L, "Could not open database %s: %s\n", path, strerror(errno));
+
+	return 1;
+}
+
+static int Database_gc(lua_State* L) {
+	Database* self = luaL_checkdatabase(L, 1);
+
+	// Free database
+	if (self->db) {
+		loc_database_unref(self->db);
+		self->db = NULL;
+	}
+
+	return 0;
+}
+
+// Created At
+
+static int Database_created_at(lua_State* L) {
+	Database* self = luaL_checkdatabase(L, 1);
+
+	// Fetch the time
+	time_t created_at = loc_database_created_at(self->db);
+
+	// Push the time onto the stack
+	lua_pushnumber(L, created_at);
+
+	return 1;
+}
+
+// Description
+
+static int Database_get_description(lua_State* L) {
+	Database* self = luaL_checkdatabase(L, 1);
+
+	// Push the description
+	lua_pushstring(L, loc_database_get_description(self->db));
+
+	return 1;
+}
+
+// License
+
+static int Database_get_license(lua_State* L) {
+	Database* self = luaL_checkdatabase(L, 1);
+
+	// Push the license
+	lua_pushstring(L, loc_database_get_license(self->db));
+
+	return 1;
+}
+
+static int Database_get_vendor(lua_State* L) {
+	Database* self = luaL_checkdatabase(L, 1);
+
+	// Push the vendor
+	lua_pushstring(L, loc_database_get_vendor(self->db));
+
+	return 1;
+}
+
+static int Database_get_as(lua_State* L) {
+	struct loc_as* as = NULL;
+	int r;
+
+	Database* self = luaL_checkdatabase(L, 1);
+
+	// Fetch number
+	uint32_t asn =
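+	/* Illustrative note: Lua numbers are C doubles, which represent every
+	   integer below 2^53 exactly, so a full 32-bit ASN survives this
+	   conversion unharmed. */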
luaL_checknumber(L, 2); + + // Fetch the AS + r = loc_database_get_as(self->db, &as, asn); + if (r) { + lua_pushnil(L); + return 1; + } + + // Create a new AS object + r = create_as(L, as); + loc_as_unref(as); + + return r; +} + +static int Database_get_country(lua_State* L) { + struct loc_country* country = NULL; + int r; + + Database* self = luaL_checkdatabase(L, 1); + + // Fetch code + const char* code = luaL_checkstring(L, 2); + + // Fetch the country + r = loc_database_get_country(self->db, &country, code); + if (r) { + lua_pushnil(L); + return 1; + } + + // Create a new country object + r = create_country(L, country); + loc_country_unref(country); + + return r; +} + +static int Database_lookup(lua_State* L) { + struct loc_network* network = NULL; + int r; + + Database* self = luaL_checkdatabase(L, 1); + + // Require a string + const char* address = luaL_checkstring(L, 2); + + // Perform lookup + r = loc_database_lookup_from_string(self->db, address, &network); + if (r) + return luaL_error(L, "Could not lookup address %s: %s\n", address, strerror(errno)); + + // Nothing found + if (!network) { + lua_pushnil(L); + return 1; + } + + // Create a network object + r = create_network(L, network); + loc_network_unref(network); + + return r; +} + +static int Database_verify(lua_State* L) { + FILE* f = NULL; + int r; + + Database* self = luaL_checkdatabase(L, 1); + + // Fetch path to key + const char* key = luaL_checkstring(L, 2); + + // Open the keyfile + f = fopen(key, "r"); + if (!f) + return luaL_error(L, "Could not open key %s: %s\n", key, strerror(errno)); + + // Verify! + r = loc_database_verify(self->db, f); + fclose(f); + + // Push result onto the stack + lua_pushboolean(L, (r == 0)); + + return 1; +} + +typedef struct enumerator { + struct loc_database_enumerator* e; +} DatabaseEnumerator; + +static DatabaseEnumerator* luaL_checkdatabaseenumerator(lua_State* L, int i) { + void* userdata = luaL_checkudata(L, i, "location.DatabaseEnumerator"); + + // Throw an error if the argument doesn't match + luaL_argcheck(L, userdata, i, "DatabaseEnumerator expected"); + + return (DatabaseEnumerator*)userdata; +} + +static int DatabaseEnumerator_gc(lua_State* L) { + DatabaseEnumerator* self = luaL_checkdatabaseenumerator(L, 1); + + if (self->e) { + loc_database_enumerator_unref(self->e); + self->e = NULL; + } + + return 0; +} + +static int DatabaseEnumerator_next_network(lua_State* L) { + struct loc_network* network = NULL; + int r; + + DatabaseEnumerator* self = luaL_checkdatabaseenumerator(L, lua_upvalueindex(1)); + + // Fetch the next network + r = loc_database_enumerator_next_network(self->e, &network); + if (r) + return luaL_error(L, "Could not fetch network: %s\n", strerror(errno)); + + // If we have received no network, we have reached the end + if (!network) { + lua_pushnil(L); + return 1; + } + + // Create a network object + r = create_network(L, network); + loc_network_unref(network); + + return r; +} + +static int Database_list_networks(lua_State* L) { + DatabaseEnumerator* e = NULL; + int r; + + Database* self = luaL_checkdatabase(L, 1); + + // Allocate a new enumerator + e = lua_newuserdata(L, sizeof(*e)); + luaL_setmetatable(L, "location.DatabaseEnumerator"); + + // Create a new enumerator + r = loc_database_enumerator_new(&e->e, self->db, LOC_DB_ENUMERATE_NETWORKS, 0); + if (r) + return luaL_error(L, "Could not create enumerator: %s\n", strerror(errno)); + + // Push the closure onto the stack + lua_pushcclosure(L, DatabaseEnumerator_next_network, 1); + + return 1; +} + +static const 
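+/*
+	The table below exposes the handlers above as methods on location.Database.
+	A minimal sketch of the resulting Lua API (database path and address are
+	illustrative only):
+
+		local location = require("location")
+
+		local db = location.Database.open("/var/lib/location/database.db")
+
+		local network = db:lookup("192.0.2.1")
+		if network then
+			print(network:get_country_code(), network:get_asn())
+		end
+*/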
struct luaL_Reg database_functions[] = {
+	{ "created_at", Database_created_at },
+	{ "get_as", Database_get_as },
+	{ "get_description", Database_get_description },
+	{ "get_country", Database_get_country },
+	{ "get_license", Database_get_license },
+	{ "get_vendor", Database_get_vendor },
+	{ "open", Database_open },
+	{ "lookup", Database_lookup },
+	{ "list_networks", Database_list_networks },
+	{ "verify", Database_verify },
+	{ "__gc", Database_gc },
+	{ NULL, NULL },
+};
+
+int register_database(lua_State* L) {
+	return register_class(L, "location.Database", database_functions);
+}
+
+static const struct luaL_Reg database_enumerator_functions[] = {
+	{ "__gc", DatabaseEnumerator_gc },
+	{ NULL, NULL },
+};
+
+int register_database_enumerator(lua_State* L) {
+	return register_class(L, "location.DatabaseEnumerator", database_enumerator_functions);
+}
diff --git a/src/lua/database.h b/src/lua/database.h
new file mode 100644
index 0000000..6a5aa4d
--- /dev/null
+++ b/src/lua/database.h
@@ -0,0 +1,26 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#ifndef LUA_LOCATION_DATABASE_H
+#define LUA_LOCATION_DATABASE_H
+
+#include <lua.h>
+#include <lauxlib.h>
+
+int register_database(lua_State* L);
+int register_database_enumerator(lua_State* L);
+
+#endif /* LUA_LOCATION_DATABASE_H */
diff --git a/src/lua/location.c b/src/lua/location.c
new file mode 100644
index 0000000..0330827
--- /dev/null
+++ b/src/lua/location.c
@@ -0,0 +1,160 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <lua.h>
+#include <lualib.h>
+#include <lauxlib.h>
+
+#include <libloc/libloc.h>
+#include <libloc/network.h>
+
+#include "location.h"
+#include "as.h"
+#include "compat.h"
+#include "country.h"
+#include "database.h"
+#include "network.h"
+
+struct loc_ctx* ctx = NULL;
+
+static int log_callback_ref = 0;
+
+static void log_callback(struct loc_ctx* _ctx, void* data, int priority, const char* file,
+	int line, const char* fn, const char* format, va_list args) __attribute__((format(printf, 7, 0)));
+
+static void log_callback(struct loc_ctx* _ctx, void* data, int priority, const char* file,
+		int line, const char* fn, const char* format, va_list args) {
+	char* message = NULL;
+	int r;
+
+	lua_State* L = data;
+
+	// Format the log message
+	r = vasprintf(&message, format, args);
+	if (r < 0)
+		return;
+
+	// Fetch the Lua callback function
+	lua_rawgeti(L, LUA_REGISTRYINDEX, log_callback_ref);
+
+	// Pass the priority as first argument
+	lua_pushnumber(L, priority);
+
+	// Pass the message as second argument
+	lua_pushstring(L, message);
+
+	// Call the function
+	lua_call(L, 2, 0);
+
+	free(message);
+}
+
+static int set_log_callback(lua_State* L) {
+	// Check if we have received a function
+	luaL_checktype(L, 1, LUA_TFUNCTION);
+
+	// Store a reference to the callback function
+	log_callback_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+	// Register our callback helper
+	if (ctx)
+		loc_set_log_callback(ctx, log_callback, L);
+
+	return 0;
+}
+
+static int set_log_level(lua_State* L) {
+	const int level = luaL_checknumber(L, 1);
+
+	// Store the new log level
+	if (ctx)
+		loc_set_log_priority(ctx, level);
+
+	return 0;
+}
+
+static int version(lua_State* L) {
+	lua_pushstring(L, PACKAGE_VERSION);
+	return 1;
+}
+
+static const struct luaL_Reg location_functions[] = {
+	{ "set_log_callback", set_log_callback },
+	{ "set_log_level", set_log_level },
+	{ "version", version },
+	{ NULL, NULL },
+};
+
+int luaopen_location(lua_State* L) {
+	int r;
+
+	// Initialize the context
+	r = loc_new(&ctx);
+	if (r)
+		return luaL_error(L,
+			"Could not initialize location context: %s\n", strerror(errno));
+
+	// Register functions
+	luaL_newlib(L, location_functions);
+
+	// Register AS type
+	register_as(L);
+
+	lua_setfield(L, -2, "AS");
+
+	// Register Country type
+	register_country(L);
+
+	lua_setfield(L, -2, "Country");
+
+	// Register Database type
+	register_database(L);
+
+	lua_setfield(L, -2, "Database");
+
+	// Register DatabaseEnumerator type
+	register_database_enumerator(L);
+
+	lua_setfield(L, -2, "DatabaseEnumerator");
+
+	// Register Network type
+	register_network(L);
+
+	lua_setfield(L, -2, "Network");
+
+	// Set DATABASE_PATH
+	lua_pushstring(L, LIBLOC_DEFAULT_DATABASE_PATH);
+	lua_setfield(L, -2, "DATABASE_PATH");
+
+	// Add flags
+	lua_pushnumber(L, LOC_NETWORK_FLAG_ANONYMOUS_PROXY);
+	lua_setfield(L, -2, "NETWORK_FLAG_ANONYMOUS_PROXY");
+
+	lua_pushnumber(L, LOC_NETWORK_FLAG_SATELLITE_PROVIDER);
+	lua_setfield(L, -2, "NETWORK_FLAG_SATELLITE_PROVIDER");
+
+	lua_pushnumber(L, LOC_NETWORK_FLAG_ANYCAST);
+	lua_setfield(L, -2, "NETWORK_FLAG_ANYCAST");
+
+	lua_pushnumber(L, LOC_NETWORK_FLAG_DROP);
+	lua_setfield(L, -2, "NETWORK_FLAG_DROP");
+
+	return 1;
+}
diff --git a/src/lua/location.h b/src/lua/location.h
new file mode 100644
index 0000000..0708988
--- /dev/null
+++ b/src/lua/location.h
@@ -0,0 +1,45 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under
the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#ifndef LUA_LOCATION_LOCATION_H
+#define LUA_LOCATION_LOCATION_H
+
+#include <lua.h>
+
+#include <libloc/libloc.h>
+
+#include "compat.h"
+
+extern struct loc_ctx* ctx;
+
+int luaopen_location(lua_State* L);
+
+static inline int register_class(lua_State* L,
+		const char* name, const struct luaL_Reg* functions) {
+	// Create a new metatable
+	luaL_newmetatable(L, name);
+
+	// Set functions
+	luaL_setfuncs(L, functions, 0);
+
+	// Configure metatable
+	lua_pushvalue(L, -1);
+	lua_setfield(L, -2, "__index");
+
+	return 1;
+}
+
+#endif /* LUA_LOCATION_LOCATION_H */
diff --git a/src/lua/network.c b/src/lua/network.c
new file mode 100644
index 0000000..2da6a1d
--- /dev/null
+++ b/src/lua/network.c
@@ -0,0 +1,229 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/
+
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#include <libloc/network.h>
+
+#include "location.h"
+#include "compat.h"
+#include "network.h"
+
+typedef struct network {
+	struct loc_network* network;
+} Network;
+
+static Network* luaL_checknetwork(lua_State* L, int i) {
+	void* userdata = luaL_checkudata(L, i, "location.Network");
+
+	// Throw an error if the argument doesn't match
+	luaL_argcheck(L, userdata, i, "Network expected");
+
+	return (Network*)userdata;
+}
+
+int create_network(lua_State* L, struct loc_network* network) {
+	// Allocate a new object
+	Network* self = (Network*)lua_newuserdata(L, sizeof(*self));
+
+	// Set metatable
+	luaL_setmetatable(L, "location.Network");
+
+	// Store network
+	self->network = loc_network_ref(network);
+
+	return 1;
+}
+
+static int Network_new(lua_State* L) {
+	struct loc_network* network = NULL;
+	const char* n = NULL;
+	int r;
+
+	// Fetch the network
+	n = luaL_checkstring(L, 1);
+
+	// Parse the string
+	r = loc_network_new_from_string(ctx, &network, n);
+	if (r)
+		return luaL_error(L, "Could not create network %s: %s\n", n, strerror(errno));
+
+	// Return the network
+	r = create_network(L, network);
+	loc_network_unref(network);
+
+	return r;
+}
+
+static int Network_gc(lua_State* L) {
+	Network* self = luaL_checknetwork(L, 1);
+
+	// Free the network
+	if (self->network) {
+		loc_network_unref(self->network);
+		self->network = NULL;
+	}
+
+	return 0;
+}
+
+static int Network_tostring(lua_State* L) {
+	Network* self = luaL_checknetwork(L, 1);
+
+	// Push string representation of the network
+	lua_pushstring(L, loc_network_str(self->network));
+
+	return 1;
+}
+
+// ASN
+
+static int Network_get_asn(lua_State* L) {
+	Network* self = luaL_checknetwork(L, 1);
+
+	uint32_t asn = loc_network_get_asn(self->network);
+
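+	/* loc_network_get_asn() yields 0 when no ASN is stored for the network;
+	   that case is mapped to nil below so Lua code can simply test
+	   "if network:get_asn() then ... end". */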
// Push ASN + if (asn) + lua_pushnumber(L, asn); + else + lua_pushnil(L); + + return 1; +} + +// Family + +static int Network_get_family(lua_State* L) { + Network* self = luaL_checknetwork(L, 1); + + // Push family + lua_pushnumber(L, loc_network_address_family(self->network)); + + return 1; +} + +// Country Code + +static int Network_get_country_code(lua_State* L) { + Network* self = luaL_checknetwork(L, 1); + + const char* country_code = loc_network_get_country_code(self->network); + + // Push country code + if (country_code && *country_code) + lua_pushstring(L, country_code); + else + lua_pushnil(L); + + return 1; +} + +// Has Flag? + +static int Network_has_flag(lua_State* L) { + Network* self = luaL_checknetwork(L, 1); + + // Fetch flag + int flag = luaL_checknumber(L, 2); + + // Push result + lua_pushboolean(L, loc_network_has_flag(self->network, flag)); + + return 1; +} + +// Subnets + +static int Network_subnets(lua_State* L) { + struct loc_network* subnet1 = NULL; + struct loc_network* subnet2 = NULL; + int r; + + Network* self = luaL_checknetwork(L, 1); + + // Make subnets + r = loc_network_subnets(self->network, &subnet1, &subnet2); + if (r) + return luaL_error(L, "Could not create subnets of %s: %s\n", + loc_network_str(self->network), strerror(errno)); + + // Create a new table + lua_createtable(L, 2, 0); + + // Create the networks & push them onto the table + create_network(L, subnet1); + loc_network_unref(subnet1); + lua_rawseti(L, -2, 1); + + create_network(L, subnet2); + loc_network_unref(subnet2); + lua_rawseti(L, -2, 2); + + return 1; +} + +// Reverse Pointer + +static int Network_reverse_pointer(lua_State* L) { + char* rp = NULL; + + Network* self = luaL_checknetwork(L, 1); + + // Fetch the suffix + const char* suffix = luaL_optstring(L, 2, NULL); + + // Make the reverse pointer + rp = loc_network_reverse_pointer(self->network, suffix); + if (!rp) { + switch (errno) { + case ENOTSUP: + lua_pushnil(L); + return 1; + + default: + return luaL_error(L, "Could not create reverse pointer: %s\n", strerror(errno)); + } + } + + // Return the response + lua_pushstring(L, rp); + free(rp); + + return 1; +} + +static const struct luaL_Reg Network_functions[] = { + { "new", Network_new }, + { "get_asn", Network_get_asn }, + { "get_country_code", Network_get_country_code }, + { "get_family", Network_get_family }, + { "has_flag", Network_has_flag }, + { "reverse_pointer", Network_reverse_pointer }, + { "subnets", Network_subnets }, + { "__gc", Network_gc }, + { "__tostring", Network_tostring }, + { NULL, NULL }, +}; + +int register_network(lua_State* L) { + return register_class(L, "location.Network", Network_functions); +} diff --git a/src/lua/network.h b/src/lua/network.h new file mode 100644 index 0000000..130aa2f --- /dev/null +++ b/src/lua/network.h @@ -0,0 +1,29 @@ +/* + libloc - A library to determine the location of someone on the Internet + + Copyright (C) 2024 IPFire Development Team + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. 
+*/
+
+#ifndef LUA_LOCATION_NETWORK_H
+#define LUA_LOCATION_NETWORK_H
+
+#include <lua.h>
+#include <lauxlib.h>
+
+#include <libloc/network.h>
+
+int register_network(lua_State* L);
+
+int create_network(lua_State* L, struct loc_network* network);
+
+#endif /* LUA_LOCATION_NETWORK_H */
diff --git a/src/network-list.c b/src/network-list.c
index bca4422..a8fcf0e 100644
--- a/src/network-list.c
+++ b/src/network-list.c
@@ -118,7 +118,7 @@ LOC_EXPORT void loc_network_list_dump(struct loc_network_list* list) {
 	for (unsigned int i = 0; i < list->size; i++) {
 		network = list->elements[i];
 
-		INFO(list->ctx, "%4d: %s\n",
+		INFO(list->ctx, "%4u: %s\n",
 			i, loc_network_str(network));
 	}
 }
@@ -322,6 +322,7 @@ LOC_EXPORT int loc_network_list_merge(
 
 int loc_network_list_summarize(struct loc_ctx* ctx,
 		const struct in6_addr* first, const struct in6_addr* last, struct loc_network_list** list) {
+	int bits;
 	int r;
 
 	if (!list) {
@@ -351,41 +352,27 @@ int loc_network_list_summarize(struct loc_ctx* ctx,
 	struct loc_network* network = NULL;
 	struct in6_addr start = *first;
 
-	const int family_bit_length = loc_address_family_bit_length(family1);
-
 	while (loc_address_cmp(&start, last) <= 0) {
-		struct in6_addr num;
-		int bits1;
-
-		// Find the number of trailing zeroes of the start address
-		if (loc_address_all_zeroes(&start))
-			bits1 = family_bit_length;
-		else {
-			bits1 = loc_address_count_trailing_zero_bits(&start);
-			if (bits1 > family_bit_length)
-				bits1 = family_bit_length;
-		}
-
-		// Subtract the start address from the last address and add one
-		// (i.e. how many addresses are in this network?)
-		r = loc_address_sub(&num, last, &start);
-		if (r)
-			return r;
-
-		loc_address_increment(&num);
-
-		// How many bits do we need to represent this address?
-		int bits2 = loc_address_bit_length(&num) - 1;
+		// Count how many leading bits the IP addresses have in common
+		bits = loc_address_common_bits(&start, last);
+		if (bits < 0)
+			return bits;
 
-		// Select the smaller one
-		int bits = (bits1 > bits2) ? bits2 : bits1;
+		// If the start and end address don't have any bits in common, we try
+		// to cut the subnet into halves and try again...
+		else if (bits == 0)
+			bits = 1;
 
 		// Create a network
-		r = loc_network_new(ctx, &network, &start, family_bit_length - bits);
+		r = loc_network_new(ctx, &network, &start, bits);
 		if (r)
 			return r;
 
-		DEBUG(ctx, "Found network %s\n", loc_network_str(network));
+		DEBUG(ctx, "Found network %s, %s -> %s\n",
+			loc_network_str(network),
+			loc_address_str(loc_network_get_first_address(network)),
+			loc_address_str(loc_network_get_last_address(network))
+		);
 
 		// Push network on the list
 		r = loc_network_list_push(*list, network);
diff --git a/src/network-tree.c b/src/network-tree.c
new file mode 100644
index 0000000..d6a2298
--- /dev/null
+++ b/src/network-tree.c
@@ -0,0 +1,680 @@
+/*
+	libloc - A library to determine the location of someone on the Internet
+
+	Copyright (C) 2024 IPFire Development Team
+
+	This library is free software; you can redistribute it and/or
+	modify it under the terms of the GNU Lesser General Public
+	License as published by the Free Software Foundation; either
+	version 2.1 of the License, or (at your option) any later version.
+
+	This library is distributed in the hope that it will be useful,
+	but WITHOUT ANY WARRANTY; without even the implied warranty of
+	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+	Lesser General Public License for more details.
+*/ + +#include +#include +#include + +#include +#include +#include +#include + +struct loc_network_tree { + struct loc_ctx* ctx; + int refcount; + + struct loc_network_tree_node* root; +}; + +struct loc_network_tree_node { + struct loc_ctx* ctx; + int refcount; + + struct loc_network_tree_node* zero; + struct loc_network_tree_node* one; + + struct loc_network* network; + + // Flags + enum loc_network_tree_node_flags { + NETWORK_TREE_NODE_DELETED = (1 << 0), + } flags; +}; + +int loc_network_tree_new(struct loc_ctx* ctx, struct loc_network_tree** tree) { + struct loc_network_tree* t = calloc(1, sizeof(*t)); + if (!t) + return 1; + + t->ctx = loc_ref(ctx); + t->refcount = 1; + + // Create the root node + int r = loc_network_tree_node_new(ctx, &t->root); + if (r) { + loc_network_tree_unref(t); + return r; + } + + DEBUG(t->ctx, "Network tree allocated at %p\n", t); + *tree = t; + return 0; +} + +static int loc_network_tree_node_has_flag(struct loc_network_tree_node* node, int flag) { + return node->flags & flag; +} + +struct loc_network_tree_node* loc_network_tree_get_root(struct loc_network_tree* tree) { + return loc_network_tree_node_ref(tree->root); +} + +static struct loc_network_tree_node* loc_network_tree_get_node(struct loc_network_tree_node* node, int path) { + struct loc_network_tree_node** n = NULL; + int r; + + switch (path) { + case 0: + n = &node->zero; + break; + + case 1: + n = &node->one; + break; + + default: + errno = EINVAL; + return NULL; + } + + // If the node existed, but has been deleted, we undelete it + if (*n && loc_network_tree_node_has_flag(*n, NETWORK_TREE_NODE_DELETED)) { + (*n)->flags &= ~NETWORK_TREE_NODE_DELETED; + + // If the desired node doesn't exist, yet, we will create it + } else if (!*n) { + r = loc_network_tree_node_new(node->ctx, n); + if (r) + return NULL; + } + + return *n; +} + +static struct loc_network_tree_node* loc_network_tree_get_path(struct loc_network_tree* tree, const struct in6_addr* address, unsigned int prefix) { + struct loc_network_tree_node* node = tree->root; + + for (unsigned int i = 0; i < prefix; i++) { + // Check if the ith bit is one or zero + node = loc_network_tree_get_node(node, loc_address_get_bit(address, i)); + } + + return node; +} + +static int __loc_network_tree_walk(struct loc_ctx* ctx, struct loc_network_tree_node* node, + int(*filter_callback)(struct loc_network* network, void* data), + int(*callback)(struct loc_network* network, void* data), void* data) { + int r; + + // If the node has been deleted, don't process it + if (loc_network_tree_node_has_flag(node, NETWORK_TREE_NODE_DELETED)) + return 0; + + // Finding a network ends the walk here + if (node->network) { + if (filter_callback) { + int f = filter_callback(node->network, data); + if (f < 0) + return f; + + // Skip network if filter function returns value greater than zero + if (f > 0) + return 0; + } + + r = callback(node->network, data); + if (r) + return r; + } + + // Walk down on the left side of the tree first + if (node->zero) { + r = __loc_network_tree_walk(ctx, node->zero, filter_callback, callback, data); + if (r) + return r; + } + + // Then walk on the other side + if (node->one) { + r = __loc_network_tree_walk(ctx, node->one, filter_callback, callback, data); + if (r) + return r; + } + + return 0; +} + +int loc_network_tree_walk(struct loc_network_tree* tree, + int(*filter_callback)(struct loc_network* network, void* data), + int(*callback)(struct loc_network* network, void* data), void* data) { + return __loc_network_tree_walk(tree->ctx, 
tree->root, filter_callback, callback, data); +} + +static void loc_network_tree_free(struct loc_network_tree* tree) { + DEBUG(tree->ctx, "Releasing network tree at %p\n", tree); + + loc_network_tree_node_unref(tree->root); + + loc_unref(tree->ctx); + free(tree); +} + +struct loc_network_tree* loc_network_tree_unref(struct loc_network_tree* tree) { + if (--tree->refcount > 0) + return tree; + + loc_network_tree_free(tree); + return NULL; +} + +static int __loc_network_tree_dump(struct loc_network* network, void* data) { + struct loc_ctx* ctx = data; + + DEBUG(ctx, "Dumping network at %p\n", network); + + const char* s = loc_network_str(network); + if (!s) + return 1; + + INFO(ctx, "%s\n", s); + + return 0; +} + +int loc_network_tree_dump(struct loc_network_tree* tree) { + DEBUG(tree->ctx, "Dumping network tree at %p\n", tree); + + return loc_network_tree_walk(tree, NULL, __loc_network_tree_dump, tree->ctx); +} + +int loc_network_tree_add_network(struct loc_network_tree* tree, struct loc_network* network) { + DEBUG(tree->ctx, "Adding network %p to tree %p\n", network, tree); + + const struct in6_addr* first_address = loc_network_get_first_address(network); + const unsigned int prefix = loc_network_raw_prefix(network); + + struct loc_network_tree_node* node = loc_network_tree_get_path(tree, first_address, prefix); + if (!node) { + ERROR(tree->ctx, "Could not find a node\n"); + return -ENOMEM; + } + + // Check if node has not been set before + if (node->network) { + DEBUG(tree->ctx, "There is already a network at this path: %s\n", + loc_network_str(node->network)); + return -EBUSY; + } + + // Point node to the network + node->network = loc_network_ref(network); + + return 0; +} + +static int loc_network_tree_delete_network( + struct loc_network_tree* tree, struct loc_network* network) { + struct loc_network_tree_node* node = NULL; + + DEBUG(tree->ctx, "Deleting network %s from tree...\n", loc_network_str(network)); + + const struct in6_addr* first_address = loc_network_get_first_address(network); + const unsigned int prefix = loc_network_raw_prefix(network); + + node = loc_network_tree_get_path(tree, first_address, prefix); + if (!node) { + ERROR(tree->ctx, "Network was not found in tree %s\n", loc_network_str(network)); + return 1; + } + + // Drop the network + if (node->network) { + loc_network_unref(node->network); + node->network = NULL; + } + + // Mark the node as deleted if it was a leaf + if (!node->zero && !node->one) + node->flags |= NETWORK_TREE_NODE_DELETED; + + return 0; +} + +static size_t __loc_network_tree_count_nodes(struct loc_network_tree_node* node) { + size_t counter = 1; + + // Don't count deleted nodes + if (loc_network_tree_node_has_flag(node, NETWORK_TREE_NODE_DELETED)) + return 0; + + if (node->zero) + counter += __loc_network_tree_count_nodes(node->zero); + + if (node->one) + counter += __loc_network_tree_count_nodes(node->one); + + return counter; +} + +size_t loc_network_tree_count_nodes(struct loc_network_tree* tree) { + return __loc_network_tree_count_nodes(tree->root); +} + +int loc_network_tree_node_new(struct loc_ctx* ctx, struct loc_network_tree_node** node) { + struct loc_network_tree_node* n = calloc(1, sizeof(*n)); + if (!n) + return -ENOMEM; + + n->ctx = loc_ref(ctx); + n->refcount = 1; + + n->zero = n->one = NULL; + + DEBUG(n->ctx, "Network node allocated at %p\n", n); + *node = n; + return 0; +} + +struct loc_network_tree_node* loc_network_tree_node_ref(struct loc_network_tree_node* node) { + if (node) + node->refcount++; + + return node; +} + +static 
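+/*
+	A minimal sketch of the node lifecycle implemented by the reference
+	counting helpers around this destructor (illustrative only):
+
+		struct loc_network_tree_node* node = NULL;
+
+		loc_network_tree_node_new(ctx, &node);	// refcount = 1
+		loc_network_tree_node_ref(node);	// refcount = 2
+		loc_network_tree_node_unref(node);	// refcount = 1
+		loc_network_tree_node_unref(node);	// refcount = 0, node is freed
+*/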
void loc_network_tree_node_free(struct loc_network_tree_node* node) { + DEBUG(node->ctx, "Releasing network node at %p\n", node); + + if (node->network) + loc_network_unref(node->network); + + if (node->zero) + loc_network_tree_node_unref(node->zero); + + if (node->one) + loc_network_tree_node_unref(node->one); + + loc_unref(node->ctx); + free(node); +} + +struct loc_network_tree_node* loc_network_tree_node_unref(struct loc_network_tree_node* node) { + if (--node->refcount > 0) + return node; + + loc_network_tree_node_free(node); + return NULL; +} + +struct loc_network_tree_node* loc_network_tree_node_get(struct loc_network_tree_node* node, unsigned int index) { + if (index == 0) + node = node->zero; + else + node = node->one; + + if (!node) + return NULL; + + return loc_network_tree_node_ref(node); +} + +int loc_network_tree_node_is_leaf(struct loc_network_tree_node* node) { + return (!!node->network); +} + +struct loc_network* loc_network_tree_node_get_network(struct loc_network_tree_node* node) { + return loc_network_ref(node->network); +} + +/* + Merge the tree! +*/ + +struct loc_network_tree_merge_ctx { + struct loc_network_tree* tree; + struct loc_network_list* networks; + unsigned int merged; +}; + +static int loc_network_tree_merge_step(struct loc_network* network, void* data) { + struct loc_network_tree_merge_ctx* ctx = (struct loc_network_tree_merge_ctx*)data; + struct loc_network* n = NULL; + struct loc_network* m = NULL; + int r; + + // How many networks do we have? + size_t i = loc_network_list_size(ctx->networks); + + // If the list is empty, just add the network + if (i == 0) + return loc_network_list_push(ctx->networks, network); + + while (i--) { + // Fetch the last network of the list + n = loc_network_list_get(ctx->networks, i); + + // Try to merge the two networks + r = loc_network_merge(&m, n, network); + if (r) + goto ERROR; + + // Did we get a result? 
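+		/* Illustration: for two mergeable networks, e.g. 10.0.0.0/24 and
+		   10.0.1.0/24 with identical properties, loc_network_merge() sets
+		   m to the enclosing supernet 10.0.0.0/23; m stays NULL whenever
+		   the two networks cannot be merged. */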
+		if (m) {
+			DEBUG(ctx->tree->ctx, "Merged networks %s + %s -> %s\n",
+				loc_network_str(n), loc_network_str(network), loc_network_str(m));
+
+			// Add the new network
+			r = loc_network_tree_add_network(ctx->tree, m);
+			switch (r) {
+				case 0:
+					break;
+
+				// There might already be a network
+				case -EBUSY:
+					r = 0;
+					goto ERROR;
+
+				default:
+					goto ERROR;
+			}
+
+			// Remove the merged networks
+			r = loc_network_tree_delete_network(ctx->tree, network);
+			if (r)
+				goto ERROR;
+
+			r = loc_network_tree_delete_network(ctx->tree, n);
+			if (r)
+				goto ERROR;
+
+			// Remove the previous network from the stack
+			r = loc_network_list_remove(ctx->networks, n);
+			if (r)
+				goto ERROR;
+
+			// Count merges
+			ctx->merged++;
+
+			// Try merging the new network with others
+			r = loc_network_tree_merge_step(m, data);
+			if (r)
+				goto ERROR;
+
+			// Add the new network to the stack
+			r = loc_network_list_push(ctx->networks, m);
+			if (r)
+				goto ERROR;
+
+			loc_network_unref(m);
+			m = NULL;
+
+			// Once we have found a merge, we are done
+			break;
+
+		// If we could not merge the two networks, we add the current one
+		} else {
+			r = loc_network_list_push(ctx->networks, network);
+			if (r)
+				goto ERROR;
+		}
+
+		loc_network_unref(n);
+		n = NULL;
+	}
+
+	const unsigned int prefix = loc_network_prefix(network);
+
+	// Remove any networks that we cannot merge
+	loc_network_list_remove_with_prefix_smaller_than(ctx->networks, prefix);
+
+ERROR:
+	if (m)
+		loc_network_unref(m);
+	if (n)
+		loc_network_unref(n);
+
+	return r;
+}
+
+static int loc_network_tree_merge(struct loc_network_tree* tree) {
+	struct loc_network_tree_merge_ctx ctx = {
+		.tree = tree,
+		.networks = NULL,
+		.merged = 0,
+	};
+	unsigned int total_merged = 0;
+	int r;
+
+	// Create a new list
+	r = loc_network_list_new(tree->ctx, &ctx.networks);
+	if (r)
+		goto ERROR;
+
+	// A single pass of the merge step can miss some merges (this has been
+	// observed to depend on the platform), so we simply keep walking the
+	// tree until a pass no longer finds anything to merge.
+	do {
+		// Reset merges
+		ctx.merged = 0;
+
+		// Walk through the entire tree
+		r = loc_network_tree_walk(tree, NULL, loc_network_tree_merge_step, &ctx);
+		if (r)
+			goto ERROR;
+
+		// Count all merges
+		total_merged += ctx.merged;
+	} while (ctx.merged > 0);
+
+	DEBUG(tree->ctx, "%u network(s) have been merged\n", total_merged);
+
+ERROR:
+	if (ctx.networks)
+		loc_network_list_unref(ctx.networks);
+
+	return r;
+}
+
+/*
+	Deduplicate the tree
+*/
+
+struct loc_network_tree_dedup_ctx {
+	struct loc_network_tree* tree;
+	struct loc_network_list* stack;
+	unsigned int* removed;
+	int family;
+};
+
+static int loc_network_tree_dedup_step(struct loc_network* network, void* data) {
+	struct loc_network_tree_dedup_ctx* ctx = (struct loc_network_tree_dedup_ctx*)data;
+	struct loc_network* n = NULL;
+	int r;
+
+	// Walk through all networks on the stack...
+	for (int i = loc_network_list_size(ctx->stack) - 1; i >= 0; i--) {
+		n = loc_network_list_get(ctx->stack, i);
+
+		// Is network a subnet?
+		if (loc_network_is_subnet(n, network)) {
+			// Do all properties match?
+			if (loc_network_properties_cmp(n, network) == 0) {
+				r = loc_network_tree_delete_network(ctx->tree, network);
+				if (r)
+					goto END;
+
+				// Count
+				(*ctx->removed)++;
+
+				// Once we removed the subnet, we are done
+				goto END;
+			}
+
+			// Once we found a subnet, we are done
+			break;
+		}
+
+		// If the network wasn't a subnet, we can remove it,
+		// because we won't ever see a subnet again.
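+		/* Why a stack is enough (illustrative): the pre-order tree walk
+		   always yields a supernet before any of its subnets, e.g.
+		   10.0.0.0/16 is visited before 10.0.0.0/24, so every potential
+		   duplicate is still on the stack when its subnet comes past. */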
+ r = loc_network_list_remove(ctx->stack, n); + if (r) + goto END; + + loc_network_unref(n); + n = NULL; + } + + // If network did not get removed, we push it into the stack + r = loc_network_list_push(ctx->stack, network); + if (r) + return r; + +END: + if (n) + loc_network_unref(n); + + return r; +} + +static int loc_network_tree_dedup_filter(struct loc_network* network, void* data) { + const struct loc_network_tree_dedup_ctx* ctx = data; + + // Match address family + return ctx->family == loc_network_address_family(network); +} + +static int loc_network_tree_dedup_one(struct loc_network_tree* tree, + const int family, unsigned int* removed) { + struct loc_network_tree_dedup_ctx ctx = { + .tree = tree, + .stack = NULL, + .removed = removed, + .family = family, + }; + int r; + + r = loc_network_list_new(tree->ctx, &ctx.stack); + if (r) + return r; + + // Walk through the entire tree + r = loc_network_tree_walk(tree, + loc_network_tree_dedup_filter, loc_network_tree_dedup_step, &ctx); + if (r) + goto ERROR; + +ERROR: + if (ctx.stack) + loc_network_list_unref(ctx.stack); + + return r; +} + +static int loc_network_tree_dedup(struct loc_network_tree* tree) { + unsigned int removed = 0; + int r; + + r = loc_network_tree_dedup_one(tree, AF_INET6, &removed); + if (r) + return r; + + r = loc_network_tree_dedup_one(tree, AF_INET, &removed); + if (r) + return r; + + DEBUG(tree->ctx, "%u network(s) have been removed\n", removed); + + return 0; +} + +static int loc_network_tree_delete_node(struct loc_network_tree* tree, + struct loc_network_tree_node** node) { + struct loc_network_tree_node* n = *node; + int r0 = 1; + int r1 = 1; + + // Return for nodes that have already been deleted + if (loc_network_tree_node_has_flag(n, NETWORK_TREE_NODE_DELETED)) + goto DELETE; + + // Delete zero + if (n->zero) { + r0 = loc_network_tree_delete_node(tree, &n->zero); + if (r0 < 0) + return r0; + } + + // Delete one + if (n->one) { + r1 = loc_network_tree_delete_node(tree, &n->one); + if (r1 < 0) + return r1; + } + + // Don't delete this node if we are a leaf + if (n->network) + return 0; + + // Don't delete this node if has child nodes that we need + if (!r0 || !r1) + return 0; + + // Don't delete root + if (tree->root == n) + return 0; + +DELETE: + // It is now safe to delete the node + loc_network_tree_node_unref(n); + *node = NULL; + + return 1; +} + +static int loc_network_tree_delete_nodes(struct loc_network_tree* tree) { + int r; + + r = loc_network_tree_delete_node(tree, &tree->root); + if (r < 0) + return r; + + return 0; +} + +int loc_network_tree_cleanup(struct loc_network_tree* tree) { + int r; + + // Deduplicate the tree + r = loc_network_tree_dedup(tree); + if (r) + return r; + + // Merge networks + r = loc_network_tree_merge(tree); + if (r) { + ERROR(tree->ctx, "Could not merge networks: %m\n"); + return r; + } + + // Delete any unneeded nodes + r = loc_network_tree_delete_nodes(tree); + if (r) + return r; + + return 0; +} diff --git a/src/network.c b/src/network.c index 47fe735..3cc2d3c 100644 --- a/src/network.c +++ b/src/network.c @@ -51,6 +51,8 @@ struct loc_network { LOC_EXPORT int loc_network_new(struct loc_ctx* ctx, struct loc_network** network, struct in6_addr* address, unsigned int prefix) { + struct loc_network* n = NULL; + // Validate the prefix if (!loc_address_valid_prefix(address, prefix)) { ERROR(ctx, "Invalid prefix in %s: %u\n", loc_address_str(address), prefix); @@ -58,7 +60,8 @@ LOC_EXPORT int loc_network_new(struct loc_ctx* ctx, struct loc_network** network return 1; } - struct 
loc_network* n = calloc(1, sizeof(*n)); + // Allocate a new network + n = calloc(1, sizeof(*n)); if (!n) return 1; @@ -72,7 +75,7 @@ LOC_EXPORT int loc_network_new(struct loc_ctx* ctx, struct loc_network** network n->prefix = prefix; // Convert the prefix into a bitmask - struct in6_addr bitmask = loc_prefix_to_bitmask(n->prefix); + const struct in6_addr bitmask = loc_prefix_to_bitmask(n->prefix); // Store the first and last address in the network n->first_address = loc_address_and(address, &bitmask); @@ -162,6 +165,10 @@ LOC_EXPORT unsigned int loc_network_prefix(struct loc_network* network) { return 0; } +unsigned int loc_network_raw_prefix(struct loc_network* network) { + return network->prefix; +} + LOC_EXPORT const struct in6_addr* loc_network_get_first_address(struct loc_network* network) { return &network->first_address; } @@ -264,7 +271,7 @@ LOC_EXPORT int loc_network_cmp(struct loc_network* self, struct loc_network* oth return 0; } -static int loc_network_properties_cmp(struct loc_network* self, struct loc_network* other) { +int loc_network_properties_cmp(struct loc_network* self, struct loc_network* other) { int r; // Check country code @@ -334,7 +341,7 @@ LOC_EXPORT int loc_network_subnets(struct loc_network* network, // Check if the new prefix is valid if (!loc_address_valid_prefix(&network->first_address, prefix)) { - ERROR(network->ctx, "Invalid prefix: %d\n", prefix); + ERROR(network->ctx, "Invalid prefix: %u\n", prefix); errno = EINVAL; return 1; } @@ -563,9 +570,8 @@ LOC_EXPORT struct loc_network_list* loc_network_exclude_list( loc_network_unref(subnet); } - if (passed) { - r = loc_network_list_push(subnets, subnet_to_check); - } + if (passed) + loc_network_list_push(subnets, subnet_to_check); loc_network_unref(subnet_to_check); } @@ -575,7 +581,7 @@ LOC_EXPORT struct loc_network_list* loc_network_exclude_list( return subnets; } -static int loc_network_merge(struct loc_network** n, +int loc_network_merge(struct loc_network** n, struct loc_network* n1, struct loc_network* n2) { struct loc_network* network = NULL; struct in6_addr address; @@ -584,6 +590,8 @@ static int loc_network_merge(struct loc_network** n, // Reset pointer *n = NULL; + DEBUG(n1->ctx, "Attempting to merge %s and %s\n", loc_network_str(n1), loc_network_str(n2)); + // Family must match if (n1->family != n2->family) return 0; @@ -596,14 +604,18 @@ static int loc_network_merge(struct loc_network** n, if (!n1->prefix || !n2->prefix) return 0; - const unsigned int prefix = loc_network_prefix(n1); + const size_t prefix = loc_network_prefix(n1); // How many bits do we need to represent this address? 
- const size_t bitlength = loc_address_bit_length(&n1->first_address) - 1; + const size_t bitlength = loc_address_bit_length(&n1->first_address); // We cannot shorten this any more - if (bitlength == prefix) + if (bitlength >= prefix) { + DEBUG(n1->ctx, "Cannot shorten this any further because we need at least %zu bits," + " but only have %zu\n", bitlength, prefix); + return 0; + } // Increment the last address of the first network address = n1->last_address; @@ -673,7 +685,7 @@ int loc_network_new_from_database_v1(struct loc_ctx* ctx, struct loc_network** n uint32_t asn = be32toh(dbobj->asn); r = loc_network_set_asn(*network, asn); if (r) { - ERROR(ctx, "Could not set ASN: %d\n", asn); + ERROR(ctx, "Could not set ASN: %u\n", asn); return r; } @@ -688,584 +700,121 @@ int loc_network_new_from_database_v1(struct loc_ctx* ctx, struct loc_network** n return 0; } -struct loc_network_tree { - struct loc_ctx* ctx; - int refcount; - - struct loc_network_tree_node* root; -}; - -struct loc_network_tree_node { - struct loc_ctx* ctx; - int refcount; - - struct loc_network_tree_node* zero; - struct loc_network_tree_node* one; - - struct loc_network* network; - - // Set if deleted - int deleted:1; -}; - -int loc_network_tree_new(struct loc_ctx* ctx, struct loc_network_tree** tree) { - struct loc_network_tree* t = calloc(1, sizeof(*t)); - if (!t) - return 1; - - t->ctx = loc_ref(ctx); - t->refcount = 1; - - // Create the root node - int r = loc_network_tree_node_new(ctx, &t->root); - if (r) { - loc_network_tree_unref(t); - return r; - } - - DEBUG(t->ctx, "Network tree allocated at %p\n", t); - *tree = t; - return 0; -} - -struct loc_network_tree_node* loc_network_tree_get_root(struct loc_network_tree* tree) { - return loc_network_tree_node_ref(tree->root); -} - -static struct loc_network_tree_node* loc_network_tree_get_node(struct loc_network_tree_node* node, int path) { - struct loc_network_tree_node** n = NULL; +static char* loc_network_reverse_pointer6(struct loc_network* network, const char* suffix) { + char* buffer = NULL; int r; - switch (path) { - case 0: - n = &node->zero; - break; - - case 1: - n = &node->one; - break; - - default: - errno = EINVAL; - return NULL; - } - - // If the node existed, but has been deleted, we undelete it - if (*n && (*n)->deleted) { - (*n)->deleted = 0; - - // If the desired node doesn't exist, yet, we will create it - } else if (!*n) { - r = loc_network_tree_node_new(node->ctx, n); - if (r) - return NULL; - } - - return *n; -} - -static struct loc_network_tree_node* loc_network_tree_get_path(struct loc_network_tree* tree, const struct in6_addr* address, unsigned int prefix) { - struct loc_network_tree_node* node = tree->root; - - for (unsigned int i = 0; i < prefix; i++) { - // Check if the ith bit is one or zero - node = loc_network_tree_get_node(node, loc_address_get_bit(address, i)); - } - - return node; -} - -static int __loc_network_tree_walk(struct loc_ctx* ctx, struct loc_network_tree_node* node, - int(*filter_callback)(struct loc_network* network, void* data), - int(*callback)(struct loc_network* network, void* data), void* data) { - int r; - - // If the node has been deleted, don't process it - if (node->deleted) - return 0; - - // Finding a network ends the walk here - if (node->network) { - if (filter_callback) { - int f = filter_callback(node->network, data); - if (f < 0) - return f; - - // Skip network if filter function returns value greater than zero - if (f > 0) - return 0; - } - - r = callback(node->network, data); - if (r) - return r; - } - - // 
Walk down on the left side of the tree first - if (node->zero) { - r = __loc_network_tree_walk(ctx, node->zero, filter_callback, callback, data); - if (r) - return r; - } - - // Then walk on the other side - if (node->one) { - r = __loc_network_tree_walk(ctx, node->one, filter_callback, callback, data); - if (r) - return r; - } - - return 0; -} - -int loc_network_tree_walk(struct loc_network_tree* tree, - int(*filter_callback)(struct loc_network* network, void* data), - int(*callback)(struct loc_network* network, void* data), void* data) { - return __loc_network_tree_walk(tree->ctx, tree->root, filter_callback, callback, data); -} - -static void loc_network_tree_free(struct loc_network_tree* tree) { - DEBUG(tree->ctx, "Releasing network tree at %p\n", tree); - - loc_network_tree_node_unref(tree->root); - - loc_unref(tree->ctx); - free(tree); -} - -struct loc_network_tree* loc_network_tree_unref(struct loc_network_tree* tree) { - if (--tree->refcount > 0) - return tree; - - loc_network_tree_free(tree); - return NULL; -} - -static int __loc_network_tree_dump(struct loc_network* network, void* data) { - DEBUG(network->ctx, "Dumping network at %p\n", network); - - const char* s = loc_network_str(network); - if (!s) - return 1; - - INFO(network->ctx, "%s\n", s); - - return 0; -} - -int loc_network_tree_dump(struct loc_network_tree* tree) { - DEBUG(tree->ctx, "Dumping network tree at %p\n", tree); - - return loc_network_tree_walk(tree, NULL, __loc_network_tree_dump, NULL); -} - -int loc_network_tree_add_network(struct loc_network_tree* tree, struct loc_network* network) { - DEBUG(tree->ctx, "Adding network %p to tree %p\n", network, tree); + unsigned int prefix = loc_network_prefix(network); - struct loc_network_tree_node* node = loc_network_tree_get_path(tree, - &network->first_address, network->prefix); - if (!node) { - ERROR(tree->ctx, "Could not find a node\n"); - return -ENOMEM; - } - - // Check if node has not been set before - if (node->network) { - DEBUG(tree->ctx, "There is already a network at this path: %s\n", - loc_network_str(node->network)); - return -EBUSY; + // Must border on a nibble + if (prefix % 4) { + errno = ENOTSUP; + return NULL; } - // Point node to the network - node->network = loc_network_ref(network); - - return 0; -} - -static int loc_network_tree_delete_network( - struct loc_network_tree* tree, struct loc_network* network) { - struct loc_network_tree_node* node = NULL; - - DEBUG(tree->ctx, "Deleting network %s from tree...\n", loc_network_str(network)); + if (!suffix) + suffix = "ip6.arpa."; - node = loc_network_tree_get_path(tree, &network->first_address, network->prefix); - if (!node) { - ERROR(tree->ctx, "Network was not found in tree %s\n", loc_network_str(network)); - return 1; - } + // Initialize the buffer + r = asprintf(&buffer, "%s", suffix); + if (r < 0) + goto ERROR; - // Drop the network - if (node->network) { - loc_network_unref(node->network); - node->network = NULL; + for (unsigned int i = 0; i < (prefix / 4); i++) { + r = asprintf(&buffer, "%x.%s", + (unsigned int)loc_address_get_nibble(&network->first_address, i), buffer); + if (r < 0) + goto ERROR; } - // Mark the node as deleted if it was a leaf - if (!node->zero && !node->one) - node->deleted = 1; - - return 0; -} - -static size_t __loc_network_tree_count_nodes(struct loc_network_tree_node* node) { - size_t counter = 1; - - // Don't count deleted nodes - if (node->deleted) - return 0; - - if (node->zero) - counter += __loc_network_tree_count_nodes(node->zero); - - if (node->one) - counter += 
__loc_network_tree_count_nodes(node->one); - - return counter; -} - -size_t loc_network_tree_count_nodes(struct loc_network_tree* tree) { - return __loc_network_tree_count_nodes(tree->root); -} - -int loc_network_tree_node_new(struct loc_ctx* ctx, struct loc_network_tree_node** node) { - struct loc_network_tree_node* n = calloc(1, sizeof(*n)); - if (!n) - return -ENOMEM; - - n->ctx = loc_ref(ctx); - n->refcount = 1; - - n->zero = n->one = NULL; - - DEBUG(n->ctx, "Network node allocated at %p\n", n); - *node = n; - return 0; -} - -struct loc_network_tree_node* loc_network_tree_node_ref(struct loc_network_tree_node* node) { - if (node) - node->refcount++; - - return node; -} - -static void loc_network_tree_node_free(struct loc_network_tree_node* node) { - DEBUG(node->ctx, "Releasing network node at %p\n", node); - - if (node->network) - loc_network_unref(node->network); - - if (node->zero) - loc_network_tree_node_unref(node->zero); - - if (node->one) - loc_network_tree_node_unref(node->one); - - loc_unref(node->ctx); - free(node); -} - -struct loc_network_tree_node* loc_network_tree_node_unref(struct loc_network_tree_node* node) { - if (--node->refcount > 0) - return node; - - loc_network_tree_node_free(node); - return NULL; -} - -struct loc_network_tree_node* loc_network_tree_node_get(struct loc_network_tree_node* node, unsigned int index) { - if (index == 0) - node = node->zero; - else - node = node->one; - - if (!node) - return NULL; - - return loc_network_tree_node_ref(node); -} - -int loc_network_tree_node_is_leaf(struct loc_network_tree_node* node) { - return (!!node->network); -} - -struct loc_network* loc_network_tree_node_get_network(struct loc_network_tree_node* node) { - return loc_network_ref(node->network); -} - -/* - Merge the tree! -*/ - -struct loc_network_tree_merge_ctx { - struct loc_network_tree* tree; - struct loc_network_list* networks; - unsigned int merged; -}; - -static int loc_network_tree_merge_step(struct loc_network* network, void* data) { - struct loc_network_tree_merge_ctx* ctx = (struct loc_network_tree_merge_ctx*)data; - struct loc_network* n = NULL; - struct loc_network* m = NULL; - int r; - - // How many networks do we have? - size_t i = loc_network_list_size(ctx->networks); - - // If the list is empty, just add the network - if (i == 0) - return loc_network_list_push(ctx->networks, network); - - while (i--) { - // Fetch the last network of the list - n = loc_network_list_get(ctx->networks, i); - - // Try to merge the two networks - r = loc_network_merge(&m, n, network); - if (r) + // Add the asterisk + if (prefix < 128) { + r = asprintf(&buffer, "*.%s", buffer); + if (r < 0) goto ERROR; - - // Did we get a result? 
- if (m) { - DEBUG(ctx->tree->ctx, "Merged networks %s + %s -> %s\n", - loc_network_str(n), loc_network_str(network), loc_network_str(m)); - - // Add the new network - r = loc_network_tree_add_network(ctx->tree, m); - switch (r) { - case 0: - break; - - // There might already be a network - case -EBUSY: - r = 0; - goto ERROR; - - default: - goto ERROR; - } - - // Remove the merge networks - r = loc_network_tree_delete_network(ctx->tree, network); - if (r) - goto ERROR; - - r = loc_network_tree_delete_network(ctx->tree, n); - if (r) - goto ERROR; - - // Add the new network to the stack - r = loc_network_list_push(ctx->networks, m); - if (r) - goto ERROR; - - // Remove the previous network from the stack - r = loc_network_list_remove(ctx->networks, n); - if (r) - goto ERROR; - - // Count merges - ctx->merged++; - - // Try merging the new network with others - r = loc_network_tree_merge_step(m, data); - if (r) - goto ERROR; - - loc_network_unref(m); - m = NULL; - - // Once we have found a merge, we are done - break; - - // If we could not merge the two networks, we add the current one - } else { - r = loc_network_list_push(ctx->networks, network); - if (r) - goto ERROR; - } - - loc_network_unref(n); - n = NULL; } - const unsigned int prefix = loc_network_prefix(network); - - // Remove any networks that we cannot merge - loc_network_list_remove_with_prefix_smaller_than(ctx->networks, prefix); + return buffer; ERROR: - if (m) - loc_network_unref(m); - if (n) - loc_network_unref(n); + if (buffer) + free(buffer); - return r; + return NULL; } -static int loc_network_tree_merge(struct loc_network_tree* tree) { - struct loc_network_tree_merge_ctx ctx = { - .tree = tree, - .networks = NULL, - .merged = 0, - }; +static char* loc_network_reverse_pointer4(struct loc_network* network, const char* suffix) { + char* buffer = NULL; int r; - // Create a new list - r = loc_network_list_new(tree->ctx, &ctx.networks); - if (r) - goto ERROR; - - // Walk through the entire tree - r = loc_network_tree_walk(tree, NULL, loc_network_tree_merge_step, &ctx); - if (r) - goto ERROR; - - DEBUG(tree->ctx, "%u network(s) have been merged\n", ctx.merged); - -ERROR: - if (ctx.networks) - loc_network_list_unref(ctx.networks); - - return r; -} - -/* - Deduplicate the tree -*/ - -struct loc_network_tree_dedup_ctx { - struct loc_network_tree* tree; - struct loc_network* network; - unsigned int removed; -}; - -static int loc_network_tree_dedup_step(struct loc_network* network, void* data) { - struct loc_network_tree_dedup_ctx* ctx = (struct loc_network_tree_dedup_ctx*)data; - - // First call when we have not seen any networks, yet - if (!ctx->network) { - ctx->network = loc_network_ref(network); - return 0; - } - - // If network is a subnet of ctx->network, and all properties match, - // we can drop the network. 
- if (loc_network_is_subnet(ctx->network, network)) { - if (loc_network_properties_cmp(ctx->network, network) == 0) { - // Increment counter - ctx->removed++; - - // Remove the network - return loc_network_tree_delete_network(ctx->tree, network); - } + unsigned int prefix = loc_network_prefix(network); - return 0; + // Must border on an octet + if (prefix % 8) { + errno = ENOTSUP; + return NULL; } - // Drop the reference to the previous network - if (ctx->network) - loc_network_unref(ctx->network); - ctx->network = loc_network_ref(network); - - return 0; -} - -static int loc_network_tree_dedup(struct loc_network_tree* tree) { - struct loc_network_tree_dedup_ctx ctx = { - .tree = tree, - .network = NULL, - .removed = 0, - }; - int r; + if (!suffix) + suffix = "in-addr.arpa."; - // Walk through the entire tree - r = loc_network_tree_walk(tree, NULL, loc_network_tree_dedup_step, &ctx); - if (r) - goto ERROR; + switch (prefix) { + case 32: + r = asprintf(&buffer, "%d.%d.%d.%d.%s", + loc_address_get_octet(&network->first_address, 3), + loc_address_get_octet(&network->first_address, 2), + loc_address_get_octet(&network->first_address, 1), + loc_address_get_octet(&network->first_address, 0), + suffix); + break; - DEBUG(tree->ctx, "%u network(s) have been removed\n", ctx.removed); + case 24: + r = asprintf(&buffer, "*.%d.%d.%d.%s", + loc_address_get_octet(&network->first_address, 2), + loc_address_get_octet(&network->first_address, 1), + loc_address_get_octet(&network->first_address, 0), + suffix); + break; -ERROR: - if (ctx.network) - loc_network_unref(ctx.network); + case 16: + r = asprintf(&buffer, "*.%d.%d.%s", + loc_address_get_octet(&network->first_address, 1), + loc_address_get_octet(&network->first_address, 0), + suffix); + break; - return r; -} + case 8: + r = asprintf(&buffer, "*.%d.%s", + loc_address_get_octet(&network->first_address, 0), + suffix); + break; -static int loc_network_tree_delete_node(struct loc_network_tree* tree, - struct loc_network_tree_node** node) { - struct loc_network_tree_node* n = *node; - int r0 = 1; - int r1 = 1; - - // Return for nodes that have already been deleted - if (n->deleted) - goto DELETE; - - // Delete zero - if (n->zero) { - r0 = loc_network_tree_delete_node(tree, &n->zero); - if (r0 < 0) - return r0; - } + case 0: + r = asprintf(&buffer, "*.%s", suffix); + break; - // Delete one - if (n->one) { - r1 = loc_network_tree_delete_node(tree, &n->one); - if (r1 < 0) - return r1; + // To make the compiler happy + default: + return NULL; } - // Don't delete this node if we are a leaf - if (n->network) - return 0; - - // Don't delete this node if has child nodes that we need - if (!r0 || !r1) - return 0; - - // Don't delete root - if (tree->root == n) - return 0; - -DELETE: - // It is now safe to delete the node - loc_network_tree_node_unref(n); - *node = NULL; - - return 1; -} - -static int loc_network_tree_delete_nodes(struct loc_network_tree* tree) { - int r; - - r = loc_network_tree_delete_node(tree, &tree->root); if (r < 0) - return r; + return NULL; - return 0; + return buffer; } -int loc_network_tree_cleanup(struct loc_network_tree* tree) { - int r; +LOC_EXPORT char* loc_network_reverse_pointer(struct loc_network* network, const char* suffix) { + switch (network->family) { + case AF_INET6: + return loc_network_reverse_pointer6(network, suffix); - // Deduplicate the tree - r = loc_network_tree_dedup(tree); - if (r) - return r; + case AF_INET: + return loc_network_reverse_pointer4(network, suffix); - // Merge networks - r = loc_network_tree_merge(tree); 
- if (r) { - ERROR(tree->ctx, "Could not merge networks: %m\n"); - return r; + default: + break; } - // Delete any unneeded nodes - r = loc_network_tree_delete_nodes(tree); - if (r) - return r; - - return 0; + return NULL; } diff --git a/src/perl/Location.xs b/src/perl/Location.xs index 6f21f2b..896cbcb 100644 --- a/src/perl/Location.xs +++ b/src/perl/Location.xs @@ -166,18 +166,26 @@ lookup_country_code(db, address) char* address; CODE: + struct loc_network *network = NULL; + const char* country_code = NULL; RETVAL = &PL_sv_undef; // Lookup network - struct loc_network *network; int err = loc_database_lookup_from_string(db, address, &network); - if (!err) { - // Extract the country code - const char* country_code = loc_network_get_country_code(network); - RETVAL = newSVpv(country_code, strlen(country_code)); + if (err) { + goto END; + } + + // Extract the country code if we have found a network + if (network) { + country_code = loc_network_get_country_code(network); + if (country_code) + RETVAL = newSVpv(country_code, strlen(country_code)); loc_network_unref(network); } + + END: OUTPUT: RETVAL @@ -188,6 +196,7 @@ lookup_network_has_flag(db, address, flag) char* flag; CODE: + struct loc_network *network = NULL; RETVAL = false; enum loc_network_flags iv = 0; @@ -204,11 +213,13 @@ lookup_network_has_flag(db, address, flag) croak("Invalid flag"); // Lookup network - struct loc_network *network; int err = loc_database_lookup_from_string(db, address, &network); + if (err) { + goto END; + } - if (!err) { - // Check if the network has the given flag. + // Check if the network has the given flag + if (network) { if (loc_network_has_flag(network, iv)) { RETVAL = true; } @@ -216,6 +227,7 @@ lookup_network_has_flag(db, address, flag) loc_network_unref(network); } + END: OUTPUT: RETVAL @@ -225,13 +237,17 @@ lookup_asn(db, address) char* address; CODE: + struct loc_network *network = NULL; RETVAL = &PL_sv_undef; // Lookup network - struct loc_network *network; int err = loc_database_lookup_from_string(db, address, &network); - if (!err) { - // Extract the ASN + if (err) { + goto END; + } + + // Extract the ASN + if (network) { unsigned int as_number = loc_network_get_asn(network); if (as_number > 0) { RETVAL = newSViv(as_number); @@ -239,6 +255,8 @@ lookup_asn(db, address) loc_network_unref(network); } + + END: OUTPUT: RETVAL diff --git a/src/python/as.c b/src/python/as.c index 4cf9987..2f4b26f 100644 --- a/src/python/as.c +++ b/src/python/as.c @@ -102,8 +102,16 @@ static int AS_set_name(ASObject* self, PyObject* value) { return 0; } -static PyObject* AS_richcompare(ASObject* self, ASObject* other, int op) { - int r = loc_as_cmp(self->as, other->as); +static PyObject* AS_richcompare(ASObject* self, PyObject* other, int op) { + int r; + + // Check for type + if (!PyObject_IsInstance(other, (PyObject *)&ASType)) + Py_RETURN_NOTIMPLEMENTED; + + ASObject* o = (ASObject*)other; + + r = loc_as_cmp(self->as, o->as); switch (op) { case Py_EQ: @@ -125,6 +133,12 @@ static PyObject* AS_richcompare(ASObject* self, ASObject* other, int op) { Py_RETURN_NOTIMPLEMENTED; } +static Py_hash_t AS_hash(ASObject* self) { + uint32_t number = loc_as_get_number(self->as); + + return number; +} + static struct PyGetSetDef AS_getsetters[] = { { "name", @@ -156,4 +170,5 @@ PyTypeObject ASType = { .tp_repr = (reprfunc)AS_repr, .tp_str = (reprfunc)AS_str, .tp_richcompare = (richcmpfunc)AS_richcompare, + .tp_hash = (hashfunc)AS_hash, }; diff --git a/src/python/country.c b/src/python/country.c index 4bb6a31..7114846 100644 --- 
a/src/python/country.c +++ b/src/python/country.c @@ -81,6 +81,10 @@ static PyObject* Country_str(CountryObject* self) { static PyObject* Country_get_name(CountryObject* self) { const char* name = loc_country_get_name(self->country); + // Return None if no name has been set + if (!name) + Py_RETURN_NONE; + return PyUnicode_FromString(name); } @@ -99,6 +103,9 @@ static int Country_set_name(CountryObject* self, PyObject* value) { static PyObject* Country_get_continent_code(CountryObject* self) { const char* code = loc_country_get_continent_code(self->country); + if (!code) + Py_RETURN_NONE; + return PyUnicode_FromString(code); } @@ -114,8 +121,16 @@ static int Country_set_continent_code(CountryObject* self, PyObject* value) { return 0; } -static PyObject* Country_richcompare(CountryObject* self, CountryObject* other, int op) { - int r = loc_country_cmp(self->country, other->country); +static PyObject* Country_richcompare(CountryObject* self, PyObject* other, int op) { + int r; + + // Check for type + if (!PyObject_IsInstance(other, (PyObject *)&CountryType)) + Py_RETURN_NOTIMPLEMENTED; + + CountryObject* o = (CountryObject*)other; + + r = loc_country_cmp(self->country, o->country); switch (op) { case Py_EQ: @@ -137,6 +152,22 @@ static PyObject* Country_richcompare(CountryObject* self, CountryObject* other, Py_RETURN_NOTIMPLEMENTED; } +static Py_hash_t Country_hash(CountryObject* self) { + PyObject* code = NULL; + Py_hash_t hash = 0; + + // Fetch the code as Python string + code = Country_get_code(self); + if (!code) + return -1; + + // Fetch the hash of that string + hash = PyObject_Hash(code); + Py_DECREF(code); + + return hash; +} + static struct PyGetSetDef Country_getsetters[] = { { "code", @@ -175,4 +206,5 @@ PyTypeObject CountryType = { .tp_repr = (reprfunc)Country_repr, .tp_str = (reprfunc)Country_str, .tp_richcompare = (richcmpfunc)Country_richcompare, + .tp_hash = (hashfunc)Country_hash, }; diff --git a/src/python/database.c b/src/python/database.c index d6ee4d0..c0d4264 100644 --- a/src/python/database.c +++ b/src/python/database.c @@ -201,35 +201,36 @@ static PyObject* Database_get_country(DatabaseObject* self, PyObject* args) { static PyObject* Database_lookup(DatabaseObject* self, PyObject* args) { struct loc_network* network = NULL; const char* address = NULL; + int r; if (!PyArg_ParseTuple(args, "s", &address)) return NULL; // Try to retrieve a matching network - int r = loc_database_lookup_from_string(self->db, address, &network); + r = loc_database_lookup_from_string(self->db, address, &network); + if (r) { + // Handle any errors + switch (errno) { + case EINVAL: + PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address); + return NULL; - // We got a network - if (r == 0) { - PyObject* obj = new_network(&NetworkType, network); - loc_network_unref(network); + default: + PyErr_SetFromErrno(PyExc_OSError); + } - return obj; + return NULL; } // Nothing found - if (!errno) + if (!network) Py_RETURN_NONE; - // Handle any errors - switch (errno) { - case EINVAL: - PyErr_Format(PyExc_ValueError, "Invalid IP address: %s", address); - - default: - PyErr_SetFromErrno(PyExc_OSError); - } + // We got a network + PyObject* obj = new_network(&NetworkType, network); + loc_network_unref(network); - return NULL; + return obj; } static PyObject* new_database_enumerator(PyTypeObject* type, struct loc_database_enumerator* enumerator) { @@ -298,14 +299,14 @@ static PyObject* Database_networks_flattened(DatabaseObject *self) { } static PyObject* Database_search_networks(DatabaseObject* 
self, PyObject* args, PyObject* kwargs) { - char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL }; + const char* kwlist[] = { "country_codes", "asns", "flags", "family", "flatten", NULL }; PyObject* country_codes = NULL; PyObject* asn_list = NULL; int flags = 0; int family = 0; int flatten = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!iip", (char**)kwlist, &PyList_Type, &country_codes, &PyList_Type, &asn_list, &flags, &family, &flatten)) return NULL; @@ -459,11 +460,11 @@ static PyObject* Database_countries(DatabaseObject* self) { } static PyObject* Database_list_bogons(DatabaseObject* self, PyObject* args, PyObject* kwargs) { - char* kwlist[] = { "family", NULL }; + const char* kwlist[] = { "family", NULL }; int family = AF_UNSPEC; // Parse arguments - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &family)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", (char**)kwlist, &family)) return NULL; return Database_iterate_all(self, LOC_DB_ENUMERATE_BOGONS, family, 0); diff --git a/src/python/location/database.py b/src/python/location/database.py index 2c93ed0..c31379c 100644 --- a/src/python/location/database.py +++ b/src/python/location/database.py @@ -1,87 +1,117 @@ """ - A lightweight wrapper around psycopg2. - - Originally part of the Tornado framework. The tornado.database module - is slated for removal in Tornado 3.0, and it is now available separately - as torndb. + A lightweight wrapper around psycopg3. """ +import asyncio import logging -import psycopg2 +import psycopg +import psycopg_pool import time +# Setup logging log = logging.getLogger("location.database") -log.propagate = 1 class Connection(object): - """ - A lightweight wrapper around MySQLdb DB-API connections. + def __init__(self, host, database, user=None, password=None): + # Stores connections assigned to tasks + self.__connections = {} - The main value we provide is wrapping rows in a dict/object so that - columns can be accessed by name. Typical usage:: + # Create a connection pool + self.pool = psycopg_pool.ConnectionPool( + "postgresql://%s:%s@%s/%s" % (user, password, host, database), - db = torndb.Connection("localhost", "mydatabase") - for article in db.query("SELECT * FROM articles"): - print article.title + # Callback to configure any new connections + configure=self.__configure, - Cursors are hidden by the implementation, but other than that, the methods - are very similar to the DB-API. + # Set limits for min/max connections in the pool + min_size=1, + max_size=512, - We explicitly set the timezone to UTC and the character encoding to - UTF-8 on all connections to avoid time zone and encoding errors. - """ - def __init__(self, host, database, user=None, password=None): - self.host = host - self.database = database - - self._db = None - self._db_args = { - "host" : host, - "database" : database, - "user" : user, - "password" : password, - "sslmode" : "require", - } - - try: - self.reconnect() - except Exception: - log.error("Cannot connect to database on %s", self.host, exc_info=True) + # Give clients up to one minute to retrieve a connection + timeout=60, - def __del__(self): - self.close() + # Close connections after they have been idle for a few seconds + max_idle=5, + ) - def close(self): + def __configure(self, conn): """ - Closes this database connection. 
+ Configures any newly opened connections """ - if getattr(self, "_db", None) is not None: - self._db.close() - self._db = None + # Enable autocommit + conn.autocommit = True + + # Return any rows as dicts + conn.row_factory = psycopg.rows.dict_row - def reconnect(self): + def connection(self, *args, **kwargs): """ - Closes the existing database connection and re-opens it. + Returns a connection from the pool """ - self.close() + # Fetch the current task + task = asyncio.current_task() - self._db = psycopg2.connect(**self._db_args) - self._db.autocommit = True + assert task, "Could not determine task" + + # Try returning the same connection to the same task + try: + return self.__connections[task] + except KeyError: + pass + + # Fetch a new connection from the pool + conn = self.__connections[task] = self.pool.getconn(*args, **kwargs) + + log.debug("Assigning database connection %s to %s" % (conn, task)) + + # When the task finishes, release the connection + task.add_done_callback(self.__release_connection) + + return conn + + def __release_connection(self, task): + # Retrieve the connection + try: + conn = self.__connections[task] + except KeyError: + return - # Initialize the timezone setting. - self.execute("SET TIMEZONE TO 'UTC'") + log.debug("Releasing database connection %s of %s" % (conn, task)) + + # Delete it + del self.__connections[task] + + # Return the connection back into the pool + self.pool.putconn(conn) + + def _execute(self, cursor, execute, query, parameters): + # Store the time we started this query + #t = time.monotonic() + + #try: + # log.debug("Running SQL query %s" % (query % parameters)) + #except Exception: + # pass + + # Execute the query + execute(query, parameters) + + # How long did this take? + #elapsed = time.monotonic() - t + + # Log the query time + #log.debug(" Query time: %.2fms" % (elapsed * 1000)) def query(self, query, *parameters, **kwparameters): """ Returns a row list for the given query and parameters. """ - cursor = self._cursor() - try: - self._execute(cursor, query, parameters, kwparameters) - column_names = [d[0] for d in cursor.description] - return [Row(zip(column_names, row)) for row in cursor] - finally: - cursor.close() + conn = self.connection() + + with conn.cursor() as cursor: + self._execute(cursor, cursor.execute, query, parameters or kwparameters) + + return [Row(row) for row in cursor] def get(self, query, *parameters, **kwparameters): """ @@ -97,104 +127,37 @@ class Connection(object): def execute(self, query, *parameters, **kwparameters): """ - Executes the given query, returning the lastrowid from the query. + Executes the given query. """ - return self.execute_lastrowid(query, *parameters, **kwparameters) + conn = self.connection() - def execute_lastrowid(self, query, *parameters, **kwparameters): - """ - Executes the given query, returning the lastrowid from the query. - """ - cursor = self._cursor() - try: - self._execute(cursor, query, parameters, kwparameters) - return cursor.lastrowid - finally: - cursor.close() - - def execute_rowcount(self, query, *parameters, **kwparameters): - """ - Executes the given query, returning the rowcount from the query. - """ - cursor = self._cursor() - try: - self._execute(cursor, query, parameters, kwparameters) - return cursor.rowcount - finally: - cursor.close() + with conn.cursor() as cursor: + self._execute(cursor, cursor.execute, query, parameters or kwparameters) def executemany(self, query, parameters): """ Executes the given query against all the given param sequences. 
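+
+			A minimal usage sketch (the table "example" and its rows are
+			illustrative only, not part of any real schema):
+
+				db.executemany(
+					"INSERT INTO example(network) VALUES(%s)",
+					[("10.0.0.0/8",), ("192.0.2.0/24",)],
+				)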
-
-		We return the lastrowid from the query.
-		"""
-		return self.executemany_lastrowid(query, parameters)
-
-	def executemany_lastrowid(self, query, parameters):
 		"""
-			Executes the given query against all the given param sequences.
+		conn = self.connection()
 
-			We return the lastrowid from the query.
-		"""
-		cursor = self._cursor()
-		try:
-			cursor.executemany(query, parameters)
-			return cursor.lastrowid
-		finally:
-			cursor.close()
+		with conn.cursor() as cursor:
+			self._execute(cursor, cursor.executemany, query, parameters)
 
-	def executemany_rowcount(self, query, parameters):
+	def transaction(self):
 		"""
-			Executes the given query against all the given param sequences.
-
-			We return the rowcount from the query.
+			Creates a new transaction on the current task's connection
 		"""
-		cursor = self._cursor()
-
-		try:
-			cursor.executemany(query, parameters)
-			return cursor.rowcount
-		finally:
-			cursor.close()
-
-	def _ensure_connected(self):
-		if self._db is None:
-			log.warning("Database connection was lost...")
-
-			self.reconnect()
-
-	def _cursor(self):
-		self._ensure_connected()
-		return self._db.cursor()
-
-	def _execute(self, cursor, query, parameters, kwparameters):
-		log.debug(
-			"Executing query: %s" % \
-				cursor.mogrify(query, kwparameters or parameters).decode(),
-		)
-
-		# Store the time when the query started
-		t = time.monotonic()
-
-		try:
-			return cursor.execute(query, kwparameters or parameters)
+		conn = self.connection()
 
-		# Catch any errors
-		except OperationalError:
-			log.error("Error connecting to database on %s", self.host)
-			self.close()
-			raise
+		return conn.transaction()
 
-		# Log how long the query took
-		finally:
-			# Determine duration the query took
-			d = time.monotonic() - t
-
-			log.debug("Query took %.2fms" % (d * 1000.0))
+	def pipeline(self):
+		"""
+			Sets the connection into pipeline mode.
+		"""
+		conn = self.connection()
 
-	def transaction(self):
-		return Transaction(self)
+		return conn.pipeline()
 
 
 class Row(dict):
@@ -204,24 +167,3 @@
 			return self[name]
 		except KeyError:
 			raise AttributeError(name)
-
-
-class Transaction(object):
-	def __init__(self, db):
-		self.db = db
-
-		self.db.execute("START TRANSACTION")
-
-	def __enter__(self):
-		return self
-
-	def __exit__(self, exctype, excvalue, traceback):
-		if exctype is not None:
-			self.db.execute("ROLLBACK")
-		else:
-			self.db.execute("COMMIT")
-
-
-# Alias some common exceptions
-IntegrityError = psycopg2.IntegrityError
-OperationalError = psycopg2.OperationalError
diff --git a/src/python/location/downloader.py b/src/python/location/downloader.py
index 3618968..3dffbc7 100644
--- a/src/python/location/downloader.py
+++ b/src/python/location/downloader.py
@@ -16,6 +16,7 @@
 #                                                                             #
 ###############################################################################
+import gzip
 import logging
 import lzma
 import os
@@ -207,3 +208,56 @@ class Downloader(object):
 				return False
 
 		return True
+
+	def retrieve(self, url, timeout=None, **kwargs):
+		"""
+			This method will fetch the content at the given URL
+			and will return a file-object to a temporary file.
+
+			If the content was compressed, it will be decompressed on the fly.
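+
+			A minimal usage sketch (placeholder URL, assuming an existing
+			Downloader instance "d"):
+
+				f = d.retrieve("https://example.org/networks.db.gz")
+				for line in f:
+					...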
+ """ + # Open a temporary file to buffer the downloaded content + t = tempfile.SpooledTemporaryFile(max_size=100 * 1024 * 1024) + + # Create a new request + req = self._make_request(url, **kwargs) + + # Send request + res = self._send_request(req, timeout=timeout) + + # Write the payload to the temporary file + with res as f: + while True: + buf = f.read(65536) + if not buf: + break + + t.write(buf) + + # Rewind the temporary file + t.seek(0) + + gzip_compressed = False + + # Fetch the content type + content_type = res.headers.get("Content-Type") + + # Decompress any gzipped response on the fly + if content_type in ("application/x-gzip", "application/gzip"): + gzip_compressed = True + + # Check for the gzip magic in case web servers send a different MIME type + elif t.read(2) == b"\x1f\x8b": + gzip_compressed = True + + # Reset again + t.seek(0) + + # Decompress the temporary file + if gzip_compressed: + log.debug("Gzip compression detected") + + t = gzip.GzipFile(fileobj=t, mode="rb") + + # Return the temporary file handle + return t diff --git a/src/python/location/importer.py b/src/python/location/importer.py deleted file mode 100644 index f391e03..0000000 --- a/src/python/location/importer.py +++ /dev/null @@ -1,266 +0,0 @@ -############################################################################### -# # -# libloc - A library to determine the location of someone on the Internet # -# # -# Copyright (C) 2020 IPFire Development Team # -# # -# This library is free software; you can redistribute it and/or # -# modify it under the terms of the GNU Lesser General Public # -# License as published by the Free Software Foundation; either # -# version 2.1 of the License, or (at your option) any later version. # -# # -# This library is distributed in the hope that it will be useful, # -# but WITHOUT ANY WARRANTY; without even the implied warranty of # -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # -# Lesser General Public License for more details. 
# -# # -############################################################################### - -import gzip -import logging -import tempfile -import urllib.request - -# Initialise logging -log = logging.getLogger("location.importer") -log.propagate = 1 - -WHOIS_SOURCES = { - # African Network Information Centre - "AFRINIC": [ - "https://ftp.afrinic.net/pub/pub/dbase/afrinic.db.gz" - ], - - # Asia Pacific Network Information Centre - "APNIC": [ - "https://ftp.apnic.net/apnic/whois/apnic.db.inet6num.gz", - "https://ftp.apnic.net/apnic/whois/apnic.db.inetnum.gz", - #"https://ftp.apnic.net/apnic/whois/apnic.db.route6.gz", - #"https://ftp.apnic.net/apnic/whois/apnic.db.route.gz", - "https://ftp.apnic.net/apnic/whois/apnic.db.aut-num.gz", - "https://ftp.apnic.net/apnic/whois/apnic.db.organisation.gz" - ], - - # American Registry for Internet Numbers - # XXX there is nothing useful for us in here - # ARIN: [ - # "https://ftp.arin.net/pub/rr/arin.db" - # ], - - # Japan Network Information Center - "JPNIC": [ - "https://ftp.nic.ad.jp/jpirr/jpirr.db.gz" - ], - - # Latin America and Caribbean Network Information Centre - "LACNIC": [ - "https://ftp.lacnic.net/lacnic/dbase/lacnic.db.gz" - ], - - # Réseaux IP Européens - "RIPE": [ - "https://ftp.ripe.net/ripe/dbase/split/ripe.db.inet6num.gz", - "https://ftp.ripe.net/ripe/dbase/split/ripe.db.inetnum.gz", - #"https://ftp.ripe.net/ripe/dbase/split/ripe.db.route6.gz", - #"https://ftp.ripe.net/ripe/dbase/split/ripe.db.route.gz", - "https://ftp.ripe.net/ripe/dbase/split/ripe.db.aut-num.gz", - "https://ftp.ripe.net/ripe/dbase/split/ripe.db.organisation.gz" - ], -} - -EXTENDED_SOURCES = { - # African Network Information Centre - # "ARIN": [ - # "https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest" - # ], - - # Asia Pacific Network Information Centre - # "APNIC": [ - # "https://ftp.apnic.net/apnic/stats/apnic/delegated-apnic-extended-latest" - # ], - - # American Registry for Internet Numbers - "ARIN": [ - "https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest" - ], - - # Latin America and Caribbean Network Information Centre - "LACNIC": [ - "https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-extended-latest" - ], - - # Réseaux IP Européens - # "RIPE": [ - # "https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-extended-latest" - # ], -} - -# List all sources -SOURCES = set(WHOIS_SOURCES|EXTENDED_SOURCES) - -class Downloader(object): - def __init__(self): - self.proxy = None - - def set_proxy(self, url): - """ - Sets a HTTP proxy that is used to perform all requests - """ - log.info("Using proxy %s" % url) - self.proxy = url - - def retrieve(self, url, data=None): - """ - This method will fetch the content at the given URL - and will return a file-object to a temporary file. - - If the content was compressed, it will be decompressed on the fly. - """ - # Open a temporary file to buffer the downloaded content - t = tempfile.SpooledTemporaryFile(max_size=100 * 1024 * 1024) - - # Create a new request - req = urllib.request.Request(url, data=data) - - # Configure proxy - if self.proxy: - req.set_proxy(self.proxy, "http") - - log.info("Retrieving %s..." 
% req.full_url) - - # Send request - res = urllib.request.urlopen(req) - - # Log the response headers - log.debug("Response Headers:") - for header in res.headers: - log.debug(" %s: %s" % (header, res.headers[header])) - - # Write the payload to the temporary file - with res as f: - while True: - buf = f.read(65536) - if not buf: - break - - t.write(buf) - - # Rewind the temporary file - t.seek(0) - - gzip_compressed = False - - # Fetch the content type - content_type = res.headers.get("Content-Type") - - # Decompress any gzipped response on the fly - if content_type in ("application/x-gzip", "application/gzip"): - gzip_compressed = True - - # Check for the gzip magic in case web servers send a different MIME type - elif t.read(2) == b"\x1f\x8b": - gzip_compressed = True - - # Reset again - t.seek(0) - - # Decompress the temporary file - if gzip_compressed: - log.debug("Gzip compression detected") - - t = gzip.GzipFile(fileobj=t, mode="rb") - - # Return the temporary file handle - return t - - def request_blocks(self, url, data=None): - """ - This method will fetch the data from the URL and return an - iterator for each block in the data. - """ - # Download the data first - t = self.retrieve(url, data=data) - - # Then, split it into blocks - return iterate_over_blocks(t) - - def request_lines(self, url, data=None): - """ - This method will fetch the data from the URL and return an - iterator for each line in the data. - """ - # Download the data first - t = self.retrieve(url, data=data) - - # Then, split it into lines - return iterate_over_lines(t) - - -def read_blocks(f): - for block in iterate_over_blocks(f): - type = None - data = {} - - for i, line in enumerate(block): - key, value = line.split(":", 1) - - # The key of the first line defines the type - if i == 0: - type = key - - # Store value - data[key] = value.strip() - - yield type, data - -def iterate_over_blocks(f, charsets=("utf-8", "latin1")): - block = [] - - for line in f: - # Skip commented lines - if line.startswith(b"#") or line.startswith(b"%"): - continue - - # Convert to string - for charset in charsets: - try: - line = line.decode(charset) - except UnicodeDecodeError: - continue - else: - break - - # Remove any comments at the end of line - line, hash, comment = line.partition("#") - - # Strip any whitespace at the end of the line - line = line.rstrip() - - # If we cut off some comment and the line is empty, we can skip it - if comment and not line: - continue - - # If the line has some content, keep collecting it - if line: - block.append(line) - continue - - # End the block on an empty line - if block: - yield block - - # Reset the block - block = [] - - # Return the last block - if block: - yield block - - -def iterate_over_lines(f): - for line in f: - # Decode the line - line = line.decode() - - # Strip the ending - yield line.rstrip() diff --git a/src/python/network.c b/src/python/network.c index c14174e..637037c 100644 --- a/src/python/network.c +++ b/src/python/network.c @@ -248,6 +248,63 @@ static PyObject* Network_get__last_address(NetworkObject* self) { return PyBytes_FromAddress(address); } +static PyObject* Network_richcompare(NetworkObject* self, PyObject* other, int op) { + int r; + + // Check for type + if (!PyObject_IsInstance(other, (PyObject *)&NetworkType)) + Py_RETURN_NOTIMPLEMENTED; + + NetworkObject* o = (NetworkObject*)other; + + r = loc_network_cmp(self->network, o->network); + + switch (op) { + case Py_EQ: + if (r == 0) + Py_RETURN_TRUE; + + Py_RETURN_FALSE; + + case Py_LT: + if (r < 0) + 
Py_RETURN_TRUE; + + Py_RETURN_FALSE; + + default: + break; + } + + Py_RETURN_NOTIMPLEMENTED; +} + +static PyObject* Network_reverse_pointer(NetworkObject* self, PyObject* args, PyObject* kwargs) { + const char* kwlist[] = { "suffix", NULL }; + const char* suffix = NULL; + char* rp = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|z", (char**)kwlist, &suffix)) + return NULL; + + rp = loc_network_reverse_pointer(self->network, suffix); + if (!rp) { + switch (errno) { + case ENOTSUP: + Py_RETURN_NONE; + + default: + PyErr_SetFromErrno(PyExc_OSError); + return NULL; + } + } + + PyObject* ret = PyUnicode_FromString(rp); + free(rp); + + return ret; +} + static struct PyMethodDef Network_methods[] = { { "exclude", @@ -267,6 +324,12 @@ static struct PyMethodDef Network_methods[] = { METH_VARARGS, NULL, }, + { + "reverse_pointer", + (PyCFunction)Network_reverse_pointer, + METH_VARARGS|METH_KEYWORDS, + NULL, + }, { "set_flag", (PyCFunction)Network_set_flag, @@ -342,4 +405,5 @@ PyTypeObject NetworkType = { .tp_getset = Network_getsetters, .tp_repr = (reprfunc)Network_repr, .tp_str = (reprfunc)Network_str, + .tp_richcompare = (richcmpfunc)Network_richcompare, }; diff --git a/src/resolv.c b/src/resolv.c index 1c4cd75..64271c4 100644 --- a/src/resolv.c +++ b/src/resolv.c @@ -133,7 +133,7 @@ LOC_EXPORT int loc_discover_latest_version(struct loc_ctx* ctx, } if (!size || (len = *payload) >= size || !len) { - ERROR(ctx, "Broken TXT record (len = %d, size = %d)\n", len, size); + ERROR(ctx, "Broken TXT record (len = %d, size = %u)\n", len, size); return -1; } diff --git a/src/scripts/location-importer.in b/src/scripts/location-importer.in index 28a4f6c..641aec2 100644 --- a/src/scripts/location-importer.in +++ b/src/scripts/location-importer.in @@ -3,7 +3,7 @@ # # # libloc - A library to determine the location of someone on the Internet # # # -# Copyright (C) 2020-2022 IPFire Development Team # +# Copyright (C) 2020-2024 IPFire Development Team # # # # This library is free software; you can redistribute it and/or # # modify it under the terms of the GNU Lesser General Public # @@ -18,8 +18,11 @@ ############################################################################### import argparse -import concurrent.futures +import asyncio +import csv +import functools import http.client +import io import ipaddress import json import logging @@ -32,7 +35,7 @@ import urllib.error # Load our location module import location import location.database -import location.importer +from location.downloader import Downloader from location.i18n import _ # Initialise logging @@ -46,6 +49,21 @@ VALID_ASN_RANGES = ( (131072, 4199999999), ) +TRANSLATED_COUNTRIES = { + # When people say UK, they mean GB + "UK" : "GB", +} + +IGNORED_COUNTRIES = set(( + # Formerly Yugoslavia + "YU", + + # Some people use ZZ to say "no country" or to hide the country + "ZZ", +)) + +# Configure the CSV parser for ARIN +csv.register_dialect("arin", delimiter=",", quoting=csv.QUOTE_ALL, quotechar="\"") class CLI(object): def parse_cli(self): @@ -87,6 +105,8 @@ class CLI(object): # Update WHOIS update_whois = subparsers.add_parser("update-whois", help=_("Update WHOIS Information")) + update_whois.add_argument("sources", nargs="*", + help=_("Only update these sources")) update_whois.set_defaults(func=self.handle_update_whois) # Update announcements @@ -101,6 +121,13 @@ class CLI(object): help=_("Update Geofeeds")) update_geofeeds.set_defaults(func=self.handle_update_geofeeds) + # Update feeds + update_feeds = subparsers.add_parser("update-feeds", 
+ help=_("Update Feeds")) + update_feeds.add_argument("feeds", nargs="*", + help=_("Only update these feeds")) + update_feeds.set_defaults(func=self.handle_update_feeds) + # Update overrides update_overrides = subparsers.add_parser("update-overrides", help=_("Update overrides"), @@ -133,15 +160,18 @@ class CLI(object): return args - def run(self): + async def run(self): # Parse command line arguments args = self.parse_cli() + # Initialize the downloader + self.downloader = Downloader() + # Initialise database self.db = self._setup_database(args) # Call function - ret = args.func(args) + ret = await args.func(args) # Return with exit code if ret: @@ -167,8 +197,10 @@ class CLI(object): first_seen_at timestamp without time zone DEFAULT CURRENT_TIMESTAMP, last_seen_at timestamp without time zone DEFAULT CURRENT_TIMESTAMP); CREATE UNIQUE INDEX IF NOT EXISTS announcements_networks ON announcements(network); - CREATE INDEX IF NOT EXISTS announcements_family ON announcements(family(network)); - CREATE INDEX IF NOT EXISTS announcements_search ON announcements USING GIST(network inet_ops); + CREATE INDEX IF NOT EXISTS announcements_search2 ON announcements + USING SPGIST(network inet_ops); + ALTER TABLE announcements ALTER COLUMN first_seen_at SET NOT NULL; + ALTER TABLE announcements ALTER COLUMN last_seen_at SET NOT NULL; -- autnums CREATE TABLE IF NOT EXISTS autnums(number bigint, name text NOT NULL); @@ -185,8 +217,8 @@ class CLI(object): ALTER TABLE networks ADD COLUMN IF NOT EXISTS original_countries text[]; ALTER TABLE networks ADD COLUMN IF NOT EXISTS source text; CREATE UNIQUE INDEX IF NOT EXISTS networks_network ON networks(network); - CREATE INDEX IF NOT EXISTS networks_family ON networks USING BTREE(family(network)); - CREATE INDEX IF NOT EXISTS networks_search ON networks USING GIST(network inet_ops); + CREATE INDEX IF NOT EXISTS networks_search2 ON networks + USING SPGIST(network inet_ops); -- geofeeds CREATE TABLE IF NOT EXISTS geofeeds( @@ -207,16 +239,39 @@ class CLI(object): ); CREATE INDEX IF NOT EXISTS geofeed_networks_geofeed_id ON geofeed_networks(geofeed_id); - CREATE INDEX IF NOT EXISTS geofeed_networks_search - ON geofeed_networks USING GIST(network inet_ops); CREATE TABLE IF NOT EXISTS network_geofeeds(network inet, url text); - CREATE UNIQUE INDEX IF NOT EXISTS network_geofeeds_unique - ON network_geofeeds(network); - CREATE INDEX IF NOT EXISTS network_geofeeds_search - ON network_geofeeds USING GIST(network inet_ops); + ALTER TABLE network_geofeeds ADD COLUMN IF NOT EXISTS source text NOT NULL; + CREATE UNIQUE INDEX IF NOT EXISTS network_geofeeds_unique2 + ON network_geofeeds(network, url); CREATE INDEX IF NOT EXISTS network_geofeeds_url ON network_geofeeds(url); + -- feeds + CREATE TABLE IF NOT EXISTS autnum_feeds( + number bigint NOT NULL, + source text NOT NULL, + name text, + country text, + is_anonymous_proxy boolean, + is_satellite_provider boolean, + is_anycast boolean, + is_drop boolean + ); + CREATE UNIQUE INDEX IF NOT EXISTS autnum_feeds_unique + ON autnum_feeds(number, source); + + CREATE TABLE IF NOT EXISTS network_feeds( + network inet NOT NULL, + source text NOT NULL, + country text, + is_anonymous_proxy boolean, + is_satellite_provider boolean, + is_anycast boolean, + is_drop boolean + ); + CREATE UNIQUE INDEX IF NOT EXISTS network_feeds_unique + ON network_feeds(network, source); + -- overrides CREATE TABLE IF NOT EXISTS autnum_overrides( number bigint NOT NULL, @@ -228,8 +283,8 @@ class CLI(object): ); CREATE UNIQUE INDEX IF NOT EXISTS 
autnum_overrides_number ON autnum_overrides(number);
-			ALTER TABLE autnum_overrides ADD COLUMN IF NOT EXISTS source text;
 			ALTER TABLE autnum_overrides ADD COLUMN IF NOT EXISTS is_drop boolean;
+			ALTER TABLE autnum_overrides DROP COLUMN IF EXISTS source;
 
 			CREATE TABLE IF NOT EXISTS network_overrides(
 				network inet NOT NULL,
@@ -240,15 +295,34 @@
 			);
 			CREATE UNIQUE INDEX IF NOT EXISTS network_overrides_network
 				ON network_overrides(network);
-			CREATE INDEX IF NOT EXISTS network_overrides_search
-				ON network_overrides USING GIST(network inet_ops);
-			ALTER TABLE network_overrides ADD COLUMN IF NOT EXISTS source text;
 			ALTER TABLE network_overrides ADD COLUMN IF NOT EXISTS is_drop boolean;
+			ALTER TABLE network_overrides DROP COLUMN IF EXISTS source;
+
+			-- Cleanup things we no longer need
+			DROP TABLE IF EXISTS geofeed_overrides;
+			DROP INDEX IF EXISTS announcements_family;
+			DROP INDEX IF EXISTS announcements_search;
+			DROP INDEX IF EXISTS geofeed_networks_search;
+			DROP INDEX IF EXISTS networks_family;
+			DROP INDEX IF EXISTS networks_search;
+			DROP INDEX IF EXISTS network_feeds_search;
+			DROP INDEX IF EXISTS network_geofeeds_unique;
+			DROP INDEX IF EXISTS network_geofeeds_search;
+			DROP INDEX IF EXISTS network_overrides_search;
 		""")
 
 		return db
 
-	def handle_write(self, ns):
+	def fetch_countries(self):
+		"""
+			Returns the set of all known country codes
+		"""
+		# Fetch all valid country codes to check parsed networks against...
+		countries = self.db.query("SELECT country_code FROM countries ORDER BY country_code")
+
+		return set((country.country_code for country in countries))
+
+	async def handle_write(self, ns):
 		"""
 			Compiles a database in libloc format out of what is in the database
 		"""
@@ -265,6 +339,9 @@
 		if ns.license:
 			writer.license = ns.license
 
+		# Analyze everything so that the query planner can hopefully make better decisions
+		self.db.execute("ANALYZE")
+
 		# Add all Autonomous Systems
 		log.info("Writing Autonomous Systems...")
 
@@ -295,143 +372,420 @@
 		# Add all networks
 		log.info("Writing networks...")
 
-		# Select all known networks
-		rows = self.db.query("""
-			WITH known_networks AS (
-				SELECT network FROM announcements
-				UNION
-				SELECT network FROM networks
-				UNION
-				SELECT network FROM network_overrides
-				UNION
-				SELECT network FROM geofeed_networks
-			),
+		# Create a new temporary table where we collect
+		# the networks that we are interested in
+		self.db.execute("""
+			CREATE TEMPORARY TABLE
+				n
+			(
+				network inet NOT NULL,
+				autnum integer,
+				country text,
+				is_anonymous_proxy boolean,
+				is_satellite_provider boolean,
+				is_anycast boolean,
+				is_drop boolean
+			)
+			WITH (FILLFACTOR = 50)
+		""")
+
+		# Add all known networks
+		self.db.execute("""
+			INSERT INTO
+				n
+			(
+				network
+			)
+
+			SELECT
+				network
+			FROM
+				announcements
+
+			UNION
+
+			SELECT
+				network
+			FROM
+				networks
-			ordered_networks AS (
+			UNION
+
+			SELECT
+				network
+			FROM
+				network_feeds
+
+			UNION
+
+			SELECT
+				network
+			FROM
+				network_overrides
+
+			UNION
+
+			SELECT
+				network
+			FROM
+				geofeed_networks
+		""")
+
+		# Create an index to search through networks faster
+		self.db.execute("""
+			CREATE INDEX
+				n_search
+			ON
+				n
+			USING
+				SPGIST(network)
+		""")
+
+		# Analyze n
+		self.db.execute("ANALYZE n")
+
+		# Apply the AS number to all networks
+		self.db.execute("""
+			-- Join all networks together with their most specific announcements
+			WITH announcements AS (
 				SELECT
-					known_networks.network AS network,
-					announcements.autnum AS autnum,
-					networks.country AS country,
+					n.network,
announcements.autnum, - -- Must be part of returned values for ORDER BY clause - masklen(announcements.network) AS sort_a, - masklen(networks.network) AS sort_b + -- Sort all merges and number them so + -- that we can later select the best one + ROW_NUMBER() + OVER + ( + PARTITION BY + n.network + ORDER BY + masklen(announcements.network) DESC + ) AS row FROM - known_networks - LEFT JOIN - announcements ON known_networks.network <<= announcements.network - LEFT JOIN - networks ON known_networks.network <<= networks.network - ORDER BY - known_networks.network, - sort_a DESC, - sort_b DESC + n + JOIN + announcements + ON + announcements.network >>= n.network ) - -- Return a list of those networks enriched with all - -- other information that we store in the database - SELECT - DISTINCT ON (network) - network, - autnum, + -- Store the result + UPDATE + n + SET + autnum = announcements.autnum + FROM + announcements + WHERE + announcements.network = n.network + AND + announcements.row = 1 + """, + ) - -- Country - COALESCE( - ( - SELECT country FROM network_overrides overrides - WHERE networks.network <<= overrides.network - ORDER BY masklen(overrides.network) DESC - LIMIT 1 - ), - ( - SELECT country FROM autnum_overrides overrides - WHERE networks.autnum = overrides.number - ), + # Apply country information + self.db.execute(""" + WITH networks AS ( + SELECT + n.network, + networks.country, + + ROW_NUMBER() + OVER ( - SELECT - geofeed_networks.country AS country - FROM - network_geofeeds + PARTITION BY + n.network + ORDER BY + masklen(networks.network) DESC + ) AS row + FROM + n + JOIN + networks + ON + networks.network >>= n.network + ) + + UPDATE + n + SET + country = networks.country + FROM + networks + WHERE + networks.network = n.network + AND + networks.row = 1 + """, + ) - -- Join the data from the geofeeds - LEFT JOIN - geofeeds ON network_geofeeds.url = geofeeds.url - LEFT JOIN - geofeed_networks ON geofeeds.id = geofeed_networks.geofeed_id + # Add all country information from Geofeeds + self.db.execute(""" + WITH geofeeds AS ( + SELECT + DISTINCT ON (geofeed_networks.network) + geofeed_networks.network, + geofeed_networks.country + FROM + geofeeds + JOIN + network_geofeeds networks + ON + geofeeds.url = networks.url + JOIN + geofeed_networks + ON + geofeeds.id = geofeed_networks.geofeed_id + AND + networks.network >>= geofeed_networks.network + ), - -- Check whether we have a geofeed for this network - WHERE - networks.network <<= network_geofeeds.network - AND - networks.network <<= geofeed_networks.network + networks AS ( + SELECT + n.network, + geofeeds.country, - -- Filter for the best result + ROW_NUMBER() + OVER + ( + PARTITION BY + n.network ORDER BY - masklen(geofeed_networks.network) DESC - LIMIT 1 - ), - networks.country - ) AS country, + masklen(geofeeds.network) DESC + ) AS row + FROM + n + JOIN + geofeeds + ON + geofeeds.network >>= n.network + ) - -- Flags - COALESCE( - ( - SELECT is_anonymous_proxy FROM network_overrides overrides - WHERE networks.network <<= overrides.network - ORDER BY masklen(overrides.network) DESC - LIMIT 1 - ), - ( - SELECT is_anonymous_proxy FROM autnum_overrides overrides - WHERE networks.autnum = overrides.number - ), - FALSE - ) AS is_anonymous_proxy, - COALESCE( - ( - SELECT is_satellite_provider FROM network_overrides overrides - WHERE networks.network <<= overrides.network - ORDER BY masklen(overrides.network) DESC - LIMIT 1 - ), - ( - SELECT is_satellite_provider FROM autnum_overrides overrides - WHERE networks.autnum = overrides.number - 
), - FALSE - ) AS is_satellite_provider, - COALESCE( - ( - SELECT is_anycast FROM network_overrides overrides - WHERE networks.network <<= overrides.network - ORDER BY masklen(overrides.network) DESC - LIMIT 1 - ), - ( - SELECT is_anycast FROM autnum_overrides overrides - WHERE networks.autnum = overrides.number - ), - FALSE - ) AS is_anycast, - COALESCE( + UPDATE + n + SET + country = networks.country + FROM + networks + WHERE + networks.network = n.network + AND + networks.row = 1 + """, + ) + + # Apply country and flags from feeds + self.db.execute(""" + WITH networks AS ( + SELECT + n.network, + network_feeds.country, + + -- Flags + network_feeds.is_anonymous_proxy, + network_feeds.is_satellite_provider, + network_feeds.is_anycast, + network_feeds.is_drop, + + ROW_NUMBER() + OVER ( - SELECT is_drop FROM network_overrides overrides - WHERE networks.network <<= overrides.network - ORDER BY masklen(overrides.network) DESC - LIMIT 1 - ), + PARTITION BY + n.network + ORDER BY + masklen(network_feeds.network) DESC + ) AS row + FROM + n + JOIN + network_feeds + ON + network_feeds.network >>= n.network + ) + + UPDATE + n + SET + country = + COALESCE(networks.country, n.country), + + is_anonymous_proxy = + COALESCE(networks.is_anonymous_proxy, n.is_anonymous_proxy), + + is_satellite_provider = + COALESCE(networks.is_satellite_provider, n.is_satellite_provider), + + is_anycast = + COALESCE(networks.is_anycast, n.is_anycast), + + is_drop = + COALESCE(networks.is_drop, n.is_drop) + FROM + networks + WHERE + networks.network = n.network + AND + networks.row = 1 + """, + ) + + # Apply country and flags from AS feeds + self.db.execute(""" + WITH networks AS ( + SELECT + n.network, + autnum_feeds.country, + + -- Flags + autnum_feeds.is_anonymous_proxy, + autnum_feeds.is_satellite_provider, + autnum_feeds.is_anycast, + autnum_feeds.is_drop + FROM + n + JOIN + autnum_feeds + ON + autnum_feeds.number = n.autnum + ) + + UPDATE + n + SET + country = + COALESCE(networks.country, n.country), + + is_anonymous_proxy = + COALESCE(networks.is_anonymous_proxy, n.is_anonymous_proxy), + + is_satellite_provider = + COALESCE(networks.is_satellite_provider, n.is_satellite_provider), + + is_anycast = + COALESCE(networks.is_anycast, n.is_anycast), + + is_drop = + COALESCE(networks.is_drop, n.is_drop) + FROM + networks + WHERE + networks.network = n.network + """) + + # Apply network overrides + self.db.execute(""" + WITH networks AS ( + SELECT + n.network, + network_overrides.country, + + -- Flags + network_overrides.is_anonymous_proxy, + network_overrides.is_satellite_provider, + network_overrides.is_anycast, + network_overrides.is_drop, + + ROW_NUMBER() + OVER ( - SELECT is_drop FROM autnum_overrides overrides - WHERE networks.autnum = overrides.number - ), - FALSE - ) AS is_drop + PARTITION BY + n.network + ORDER BY + masklen(network_overrides.network) DESC + ) AS row + FROM + n + JOIN + network_overrides + ON + network_overrides.network >>= n.network + ) + + UPDATE + n + SET + country = + COALESCE(networks.country, n.country), + + is_anonymous_proxy = + COALESCE(networks.is_anonymous_proxy, n.is_anonymous_proxy), + + is_satellite_provider = + COALESCE(networks.is_satellite_provider, n.is_satellite_provider), + + is_anycast = + COALESCE(networks.is_anycast, n.is_anycast), + + is_drop = + COALESCE(networks.is_drop, n.is_drop) FROM - ordered_networks networks + networks + WHERE + networks.network = n.network + AND + networks.row = 1 + """) + + # Apply AS overrides + self.db.execute(""" + WITH networks AS ( + SELECT + 
n.network, + autnum_overrides.country, + + -- Flags + autnum_overrides.is_anonymous_proxy, + autnum_overrides.is_satellite_provider, + autnum_overrides.is_anycast, + autnum_overrides.is_drop + FROM + n + JOIN + autnum_overrides + ON + autnum_overrides.number = n.autnum + ) + + UPDATE + n + SET + country = + COALESCE(networks.country, n.country), + + is_anonymous_proxy = + COALESCE(networks.is_anonymous_proxy, n.is_anonymous_proxy), + + is_satellite_provider = + COALESCE(networks.is_satellite_provider, n.is_satellite_provider), + + is_anycast = + COALESCE(networks.is_anycast, n.is_anycast), + + is_drop = + COALESCE(networks.is_drop, n.is_drop) + FROM + networks + WHERE + networks.network = n.network + """) + + # Here we could remove some networks that we no longer need, but since we + # already have implemented our deduplication/merge algorithm this would not + # be necessary. + + # Export the entire temporary table + rows = self.db.query(""" + SELECT + * + FROM + n + ORDER BY + network """) for row in rows: - network = writer.add_network(row.network) + network = writer.add_network("%s" % row.network) # Save country if row.country: @@ -456,7 +810,17 @@ class CLI(object): # Add all countries log.info("Writing countries...") - rows = self.db.query("SELECT * FROM countries ORDER BY country_code") + + # Select all countries + rows = self.db.query(""" + SELECT + * + FROM + countries + ORDER BY + country_code + """, + ) for row in rows: c = writer.add_country(row.country_code) @@ -468,182 +832,333 @@ class CLI(object): for file in ns.file: writer.write(file) - def handle_update_whois(self, ns): - downloader = location.importer.Downloader() - + async def handle_update_whois(self, ns): # Did we run successfully? - error = False + success = True + + sources = ( + # African Network Information Centre + ("AFRINIC", ( + (self._import_standard_format, "https://ftp.afrinic.net/pub/pub/dbase/afrinic.db.gz"), + )), + + # Asia Pacific Network Information Centre + ("APNIC", ( + (self._import_standard_format, "https://ftp.apnic.net/apnic/whois/apnic.db.inet6num.gz"), + (self._import_standard_format, "https://ftp.apnic.net/apnic/whois/apnic.db.inetnum.gz"), + (self._import_standard_format, "https://ftp.apnic.net/apnic/whois/apnic.db.aut-num.gz"), + (self._import_standard_format, "https://ftp.apnic.net/apnic/whois/apnic.db.organisation.gz"), + )), + + # American Registry for Internet Numbers + ("ARIN", ( + (self._import_extended_format, "https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest"), + (self._import_arin_as_names, "https://ftp.arin.net/pub/resource_registry_service/asns.csv"), + )), + + # Japan Network Information Center + ("JPNIC", ( + (self._import_standard_format, "https://ftp.nic.ad.jp/jpirr/jpirr.db.gz"), + )), + + # Latin America and Caribbean Network Information Centre + ("LACNIC", ( + (self._import_standard_format, "https://ftp.lacnic.net/lacnic/dbase/lacnic.db.gz"), + (self._import_extended_format, "https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-extended-latest"), + )), + + # Réseaux IP Européens + ("RIPE", ( + (self._import_standard_format, "https://ftp.ripe.net/ripe/dbase/split/ripe.db.inet6num.gz"), + (self._import_standard_format, "https://ftp.ripe.net/ripe/dbase/split/ripe.db.inetnum.gz"), + (self._import_standard_format, "https://ftp.ripe.net/ripe/dbase/split/ripe.db.aut-num.gz"), + (self._import_standard_format, "https://ftp.ripe.net/ripe/dbase/split/ripe.db.organisation.gz"), + )), + ) - # Fetch all valid country codes to check parsed networks aganist - 
validcountries = self.countries + # Fetch all valid country codes to check parsed networks against + countries = self.fetch_countries() + + # Check if we have countries + if not countries: + log.error("Please import countries before importing any WHOIS data") + return 1 # Iterate over all potential sources - for source in sorted(location.importer.SOURCES): - with self.db.transaction(): - # Create some temporary tables to store parsed data - self.db.execute(""" - CREATE TEMPORARY TABLE _autnums(number integer NOT NULL, - organization text NOT NULL, source text NOT NULL) ON COMMIT DROP; - CREATE UNIQUE INDEX _autnums_number ON _autnums(number); + for name, feeds in sources: + # Skip anything that should not be updated + if ns.sources and not name in ns.sources: + continue - CREATE TEMPORARY TABLE _organizations(handle text NOT NULL, - name text NOT NULL, source text NOT NULL) ON COMMIT DROP; - CREATE UNIQUE INDEX _organizations_handle ON _organizations(handle); + try: + await self._process_source(name, feeds, countries) - CREATE TEMPORARY TABLE _rirdata(network inet NOT NULL, country text NOT NULL, - original_countries text[] NOT NULL, source text NOT NULL) - ON COMMIT DROP; - CREATE INDEX _rirdata_search ON _rirdata - USING BTREE(family(network), masklen(network)); - CREATE UNIQUE INDEX _rirdata_network ON _rirdata(network); - """) + # Log an error but continue if an exception occurs + except Exception as e: + log.error("Error processing source %s" % name, exc_info=True) + success = False - # Remove all previously imported content - self.db.execute("DELETE FROM networks WHERE source = %s", source) + # Return a non-zero exit code for errors + return 0 if success else 1 - try: - # Fetch WHOIS sources - for url in location.importer.WHOIS_SOURCES.get(source, []): - for block in downloader.request_blocks(url): - self._parse_block(block, source, validcountries) - - # Fetch extended sources - for url in location.importer.EXTENDED_SOURCES.get(source, []): - for line in downloader.request_lines(url): - self._parse_line(line, source, validcountries) - except urllib.error.URLError as e: - log.error("Could not retrieve data from %s: %s" % (source, e)) - error = True - - # Continue with the next source - continue + async def _process_source(self, source, feeds, countries): + """ + This function processes one source + """ + # Wrap everything into one large transaction + with self.db.transaction(): + # Remove all previously imported content + self.db.execute("DELETE FROM autnums WHERE source = %s", source) + self.db.execute("DELETE FROM networks WHERE source = %s", source) + self.db.execute("DELETE FROM network_geofeeds WHERE source = %s", source) + + # Create some temporary tables to store parsed data + self.db.execute(""" + CREATE TEMPORARY TABLE _autnums(number integer NOT NULL, + organization text NOT NULL, source text NOT NULL) ON COMMIT DROP; + CREATE UNIQUE INDEX _autnums_number ON _autnums(number); + + CREATE TEMPORARY TABLE _organizations(handle text NOT NULL, + name text NOT NULL, source text NOT NULL) ON COMMIT DROP; + CREATE UNIQUE INDEX _organizations_handle ON _organizations(handle); + + CREATE TEMPORARY TABLE _rirdata(network inet NOT NULL, country text, + original_countries text[] NOT NULL, source text NOT NULL) + ON COMMIT DROP; + CREATE INDEX _rirdata_search ON _rirdata + USING BTREE(family(network), masklen(network)); + CREATE UNIQUE INDEX _rirdata_network ON _rirdata(network); + """) - # Process all parsed networks from every RIR we happen to have access to, - # insert the largest 
network chunks into the networks table immediately... - families = self.db.query("SELECT DISTINCT family(network) AS family FROM _rirdata \ - ORDER BY family(network)") + # Parse all feeds + for callback, url, *args in feeds: + # Retrieve the feed + f = self.downloader.retrieve(url) - for family in (row.family for row in families): - # Fetch the smallest mask length in our data set - smallest = self.db.get(""" - SELECT - MIN( - masklen(network) - ) AS prefix - FROM - _rirdata - WHERE - family(network) = %s""", - family, + # Call the callback + with self.db.pipeline(): + await callback(source, countries, f, *args) + + # Process all parsed networks from every RIR we happen to have access to, + # insert the largest network chunks into the networks table immediately... + families = self.db.query(""" + SELECT DISTINCT + family(network) AS family + FROM + _rirdata + ORDER BY + family(network) + """, + ) + + for family in (row.family for row in families): + # Fetch the smallest mask length in our data set + smallest = self.db.get(""" + SELECT + MIN( + masklen(network) + ) AS prefix + FROM + _rirdata + WHERE + family(network) = %s + """, family, + ) + + # Copy all networks + self.db.execute(""" + INSERT INTO + networks + ( + network, + country, + original_countries, + source ) + SELECT + network, + country, + original_countries, + source + FROM + _rirdata + WHERE + masklen(network) = %s + AND + family(network) = %s + ON CONFLICT DO + NOTHING""", + smallest.prefix, + family, + ) + + # ... determine any other prefixes for this network family, ... + prefixes = self.db.query(""" + SELECT + DISTINCT masklen(network) AS prefix + FROM + _rirdata + WHERE + family(network) = %s + ORDER BY + masklen(network) ASC + OFFSET 1 + """, family, + ) - # Copy all networks + # ... and insert networks with this prefix in case they provide additional + # information (i. e. subnet of a larger chunk with a different country) + for prefix in (row.prefix for row in prefixes): self.db.execute(""" - INSERT INTO - networks - ( - network, - country, - original_countries, - source + WITH candidates AS ( + SELECT + _rirdata.network, + _rirdata.country, + _rirdata.original_countries, + _rirdata.source + FROM + _rirdata + WHERE + family(_rirdata.network) = %s + AND + masklen(_rirdata.network) = %s + ), + filtered AS ( + SELECT + DISTINCT ON (c.network) + c.network, + c.country, + c.original_countries, + c.source, + masklen(networks.network), + networks.country AS parent_country + FROM + candidates c + LEFT JOIN + networks + ON + c.network << networks.network + ORDER BY + c.network, + masklen(networks.network) DESC NULLS LAST ) + INSERT INTO + networks(network, country, original_countries, source) SELECT network, country, original_countries, source FROM - _rirdata + filtered WHERE - masklen(network) = %s - AND - family(network) = %s - ON CONFLICT DO - NOTHING""", - smallest.prefix, - family, + parent_country IS NULL + OR + country <> parent_country + ON CONFLICT DO NOTHING + """, family, prefix, ) - # ... determine any other prefixes for this network family, ... 
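The cascade above is what keeps the deduplication cheap: the least specific prefixes of each address family are copied into the networks table first, and a more specific prefix is only inserted when it is either not covered at all or disagrees with the country of its closest covering network. A minimal sketch of that redundancy test, using only the standard ipaddress module (networks and countries invented for illustration; this is not the importer's code):

    import ipaddress

    # A covering network with its country, and a more specific candidate
    parent, parent_country = ipaddress.ip_network("192.0.2.0/24"), "DE"
    candidate, country = ipaddress.ip_network("192.0.2.128/25"), "DE"

    # Mirrors the WHERE clause of the INSERT above: keep the candidate only
    # if it is uncovered, or if it contradicts its parent's country
    if not candidate.subnet_of(parent) or country != parent_country:
        print("keep", candidate)
    else:
        print("drop", candidate)  # redundant, the parent already says DE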
- prefixes = self.db.query(""" - SELECT - DISTINCT masklen(network) AS prefix - FROM - _rirdata - WHERE - family(network) = %s - ORDER BY - masklen(network) ASC - OFFSET 1""", - family, - ) + self.db.execute(""" + INSERT INTO + autnums + ( + number, + name, + source + ) + SELECT + _autnums.number, + _organizations.name, + _organizations.source + FROM + _autnums + JOIN + _organizations ON _autnums.organization = _organizations.handle + ON CONFLICT + ( + number + ) + DO UPDATE + SET name = excluded.name + """, + ) - # ... and insert networks with this prefix in case they provide additional - # information (i. e. subnet of a larger chunk with a different country) - for prefix in (row.prefix for row in prefixes): - self.db.execute(""" - WITH candidates AS ( - SELECT - _rirdata.network, - _rirdata.country, - _rirdata.original_countries, - _rirdata.source - FROM - _rirdata - WHERE - family(_rirdata.network) = %s - AND - masklen(_rirdata.network) = %s - ), - filtered AS ( - SELECT - DISTINCT ON (c.network) - c.network, - c.country, - c.original_countries, - c.source, - masklen(networks.network), - networks.country AS parent_country - FROM - candidates c - LEFT JOIN - networks - ON - c.network << networks.network - ORDER BY - c.network, - masklen(networks.network) DESC NULLS LAST - ) - INSERT INTO - networks(network, country, original_countries, source) - SELECT - network, - country, - original_countries, - source - FROM - filtered - WHERE - parent_country IS NULL - OR - country <> parent_country - ON CONFLICT DO NOTHING""", - family, prefix, - ) + async def _import_standard_format(self, source, countries, f, *args): + """ + Imports a single standard format source feed + """ + # Iterate over all blocks + for block in iterate_over_blocks(f): + self._parse_block(block, source, countries) + + async def _import_extended_format(self, source, countries, f, *args): + # Iterate over all lines + for line in iterate_over_lines(f): + self._parse_line(line, source, countries) + + async def _import_arin_as_names(self, source, countries, f, *args): + # Wrap the data to text + f = io.TextIOWrapper(f) + + # Walk through the file + for line in csv.DictReader(f, dialect="arin"): + # Fetch status + status = line.get("Status") + + # We are only interested in anything managed by ARIN + if not status == "Full Registry Services": + continue - self.db.execute(""" - INSERT INTO autnums(number, name, source) - SELECT _autnums.number, _organizations.name, _organizations.source FROM _autnums - JOIN _organizations ON _autnums.organization = _organizations.handle - ON CONFLICT (number) DO UPDATE SET name = excluded.name; - """) + # Fetch organization name + name = line.get("Org Name") - # Download and import (technical) AS names from ARIN - with self.db.transaction(): - self._import_as_names_from_arin() + # Extract ASNs + first_asn = line.get("Start AS Number") + last_asn = line.get("End AS Number") - # Return a non-zero exit code for errors - return 1 if error else 0 + # Cast to a number + try: + first_asn = int(first_asn) + except TypeError as e: + log.warning("Could not parse ASN '%s'" % first_asn) + continue + + try: + last_asn = int(last_asn) + except TypeError as e: + log.warning("Could not parse ASN '%s'" % last_asn) + continue + + # Check if the range is valid + if last_asn < first_asn: + log.warning("Invalid ASN range %s-%s" % (first_asn, last_asn)) + + # Insert everything into the database + for asn in range(first_asn, last_asn + 1): + if not self._check_parsed_asn(asn): + log.warning("Skipping invalid ASN %s" % asn) 
+ continue + + self.db.execute(""" + INSERT INTO + autnums + ( + number, + name, + source + ) + VALUES + ( + %s, %s, %s + ) + ON CONFLICT + ( + number + ) + DO NOTHING + """, asn, name, "ARIN", + ) def _check_parsed_network(self, network): """ @@ -655,9 +1170,6 @@ class CLI(object): (b) covering a too large chunk of the IP address space (prefix length is < 7 for IPv4 networks, and < 10 for IPv6) (c) "0.0.0.0" or "::" as a network address - (d) are too small for being publicly announced (we have decided not to - process them at the moment, as they significantly enlarge our - database without providing very helpful additional information) This unfortunately is necessary due to brain-dead clutter across various RIR databases, causing mismatches and eventually disruptions. @@ -665,45 +1177,36 @@ class CLI(object): We will return False in case a network is not suitable for adding it to our database, and True otherwise. """ + # Check input + if isinstance(network, ipaddress.IPv6Network): + pass + elif isinstance(network, ipaddress.IPv4Network): + pass + else: + raise ValueError("Invalid network: %s (type %s)" % (network, type(network))) - if not network or not (isinstance(network, ipaddress.IPv4Network) or isinstance(network, ipaddress.IPv6Network)): - return False - + # Ignore anything that isn't globally routable if not network.is_global: log.debug("Skipping non-globally routable network: %s" % network) return False - if network.version == 4: - if network.prefixlen < 7: - log.debug("Skipping too big IP chunk: %s" % network) - return False - - if network.prefixlen > 24: - log.debug("Skipping network too small to be publicly announced: %s" % network) - return False - - if str(network.network_address) == "0.0.0.0": - log.debug("Skipping network based on 0.0.0.0: %s" % network) - return False + # Ignore anything that is unspecified IP range (See RFC 5735 for IPv4 or RFC 2373 for IPv6) + elif network.is_unspecified: + log.debug("Skipping unspecified network: %s" % network) + return False - elif network.version == 6: + # IPv6 + if network.version == 6: if network.prefixlen < 10: log.debug("Skipping too big IP chunk: %s" % network) return False - if network.prefixlen > 48: - log.debug("Skipping network too small to be publicly announced: %s" % network) - return False - - if str(network.network_address) == "::": - log.debug("Skipping network based on '::': %s" % network) + # IPv4 + elif network.version == 4: + if network.prefixlen < 7: + log.debug("Skipping too big IP chunk: %s" % network) return False - else: - # This should not happen... - log.warning("Skipping network of unknown family, this should not happen: %s" % network) - return False - # In case we have made it here, the network is considered to # be suitable for libloc consumption... return True @@ -721,7 +1224,30 @@ class CLI(object): log.info("Supplied ASN %s out of publicly routable ASN ranges" % asn) return False - def _parse_block(self, block, source_key, validcountries = None): + def _check_geofeed_url(self, url): + """ + This function checks if a Geofeed URL is valid. + + If so, it returns the normalized URL which should be stored instead of + the original one. 
+ """ + # Parse the URL + try: + url = urllib.parse.urlparse(url) + except ValueError as e: + log.warning("Invalid URL %s: %s" % (url, e)) + return + + # Make sure that this is a HTTPS URL + if not url.scheme == "https": + log.warning("Skipping Geofeed URL that is not using HTTPS: %s" \ + % url.geturl()) + return + + # Normalize the URL and convert it back + return url.geturl() + + def _parse_block(self, block, source_key, countries): # Get first line to find out what type of block this is line = block[0] @@ -731,7 +1257,7 @@ class CLI(object): # inetnum if line.startswith("inet6num:") or line.startswith("inetnum:"): - return self._parse_inetnum_block(block, source_key, validcountries) + return self._parse_inetnum_block(block, source_key, countries) # organisation elif line.startswith("organisation:"): @@ -784,13 +1310,9 @@ class CLI(object): autnum.get("asn"), autnum.get("org"), source_key, ) - def _parse_inetnum_block(self, block, source_key, validcountries = None): - log.debug("Parsing inetnum block:") - + def _parse_inetnum_block(self, block, source_key, countries): inetnum = {} for line in block: - log.debug(line) - # Split line key, val = split_line(line) @@ -850,21 +1372,28 @@ class CLI(object): inetnum[key] = [ipaddress.ip_network(val, strict=False)] elif key == "country": - val = val.upper() + cc = val.upper() - # Catch RIR data objects with more than one country code... - if not key in inetnum: - inetnum[key] = [] - else: - if val in inetnum.get("country"): - # ... but keep this list distinct... - continue + # Ignore certain country codes + if cc in IGNORED_COUNTRIES: + log.debug("Ignoring country code '%s'" % cc) + continue - # When people set country codes to "UK", they actually mean "GB" - if val == "UK": - val = "GB" + # Translate country codes + try: + cc = TRANSLATED_COUNTRIES[cc] + except KeyError: + pass - inetnum[key].append(val) + # Do we know this country? + if not cc in countries: + log.warning("Skipping invalid country code '%s'" % cc) + continue + + try: + inetnum[key].append(cc) + except KeyError: + inetnum[key] = [cc] # Parse the geofeed attribute elif key == "geofeed": @@ -877,63 +1406,72 @@ class CLI(object): inetnum["geofeed"] = m.group(1) # Skip empty objects - if not inetnum or not "country" in inetnum: + if not inetnum: return - # Prepare skipping objects with unknown country codes... - invalidcountries = [singlecountry for singlecountry in inetnum.get("country") if singlecountry not in validcountries] - # Iterate through all networks enumerated from above, check them for plausibility and insert # them into the database, if _check_parsed_network() succeeded for single_network in inetnum.get("inet6num") or inetnum.get("inetnum"): - if self._check_parsed_network(single_network): - # Skip objects with unknown country codes if they are valid to avoid log spam... - if validcountries and invalidcountries: - log.warning("Skipping network with bogus countr(y|ies) %s (original countries: %s): %s" % \ - (invalidcountries, inetnum.get("country"), inetnum.get("inet6num") or inetnum.get("inetnum"))) - break + if not self._check_parsed_network(single_network): + continue - # Everything is fine here, run INSERT statement... 
- self.db.execute("INSERT INTO _rirdata(network, country, original_countries, source) \ - VALUES(%s, %s, %s, %s) ON CONFLICT (network) DO UPDATE SET country = excluded.country", - "%s" % single_network, inetnum.get("country")[0], inetnum.get("country"), source_key, + # Fetch the countries or use a list with an empty country + countries = inetnum.get("country", [None]) + + # Insert the network into the database but only use the first country code + for cc in countries: + self.db.execute(""" + INSERT INTO + _rirdata + ( + network, + country, + original_countries, + source + ) + VALUES + ( + %s, %s, %s, %s + ) + ON CONFLICT (network) + DO UPDATE SET country = excluded.country + """, "%s" % single_network, cc, [cc for cc in countries if cc], source_key, ) - # Update any geofeed information - geofeed = inetnum.get("geofeed", None) - if geofeed: - self._parse_geofeed(geofeed, single_network) - - # Delete any previous geofeeds - else: - self.db.execute("DELETE FROM network_geofeeds WHERE network = %s", - "%s" % single_network) + # If there are more than one country, we will only use the first one + break - def _parse_geofeed(self, url, single_network): - # Parse the URL - url = urllib.parse.urlparse(url) + # Update any geofeed information + geofeed = inetnum.get("geofeed", None) + if geofeed: + self._parse_geofeed(source_key, geofeed, single_network) - # Make sure that this is a HTTPS URL - if not url.scheme == "https": - log.debug("Geofeed URL is not using HTTPS: %s" % geofeed) + def _parse_geofeed(self, source, url, single_network): + # Check the URL + url = self._check_geofeed_url(url) + if not url: return - # Put the URL back together normalized - url = url.geturl() - # Store/update any geofeeds self.db.execute(""" INSERT INTO - network_geofeeds( - network, - url - ) - VALUES( - %s, %s + network_geofeeds + ( + network, + url, + source ) - ON CONFLICT (network) DO - UPDATE SET url = excluded.url""", - "%s" % single_network, url, + VALUES + ( + %s, %s, %s + ) + ON CONFLICT + ( + network, url + ) + DO UPDATE SET + source = excluded.source + """, "%s" % single_network, url, source, ) def _parse_org_block(self, block, source_key): @@ -957,7 +1495,7 @@ class CLI(object): org.get("organisation"), org.get("org-name"), source_key, ) - def _parse_line(self, line, source_key, validcountries = None): + def _parse_line(self, line, source_key, validcountries=None): # Skip version line if line.startswith("2"): return @@ -972,6 +1510,15 @@ class CLI(object): log.warning("Could not parse line: %s" % line) return + # Skip ASN + if type == "asn": + return + + # Skip any unknown protocols + elif not type in ("ipv6", "ipv4"): + log.warning("Unknown IP protocol '%s'" % type) + return + # Skip any lines that are for stats only or do not have a country # code at all (avoids log spam below) if not country_code or country_code == '*': @@ -983,10 +1530,6 @@ class CLI(object): (country_code, line)) return - if type in ("ipv6", "ipv4"): - return self._parse_ip_line(country_code, type, line, source_key) - - def _parse_ip_line(self, country, type, line, source_key): try: address, prefix, date, status, organization = line.split("|") except ValueError: @@ -1024,66 +1567,30 @@ class CLI(object): if not self._check_parsed_network(network): return - self.db.execute("INSERT INTO networks(network, country, original_countries, source) \ - VALUES(%s, %s, %s, %s) ON CONFLICT (network) DO \ - UPDATE SET country = excluded.country", - "%s" % network, country, [country], source_key, - ) - - def _import_as_names_from_arin(self): - 
downloader = location.importer.Downloader() - - # XXX: Download AS names file from ARIN (note that these names appear to be quite - # technical, not intended for human consumption, as description fields in - # organisation handles for other RIRs are - however, this is what we have got, - # and in some cases, it might be still better than nothing) - for line in downloader.request_lines("https://ftp.arin.net/info/asn.txt"): - # Valid lines start with a space, followed by the number of the Autonomous System ... - if not line.startswith(" "): - continue - - # Split line and check if there is a valid ASN in it... - asn, name = line.split()[0:2] - - try: - asn = int(asn) - except ValueError: - log.debug("Skipping ARIN AS names line not containing an integer for ASN") - continue - - # Filter invalid ASNs... - if not self._check_parsed_asn(asn): - continue - - # Skip any AS name that appears to be a placeholder for a different RIR or entity... - if re.match(r"^(ASN-BLK|)(AFCONC|AFRINIC|APNIC|ASNBLK|LACNIC|RIPE|IANA)(?:\d?$|\-)", name): - continue - - # Bail out in case the AS name contains anything we do not expect here... - if re.search(r"[^a-zA-Z0-9-_]", name): - log.debug("Skipping ARIN AS name for %s containing invalid characters: %s" % \ - (asn, name)) - - # Things look good here, run INSERT statement and skip this one if we already have - # a (better?) name for this Autonomous System... - self.db.execute(""" - INSERT INTO autnums( - number, - name, - source - ) VALUES (%s, %s, %s) - ON CONFLICT (number) DO NOTHING""", - asn, - name, - "ARIN", + self.db.execute(""" + INSERT INTO + networks + ( + network, + country, + original_countries, + source ) + VALUES + ( + %s, %s, %s, %s + ) + ON CONFLICT (network) + DO UPDATE SET country = excluded.country + """, "%s" % network, country_code, [country_code], source_key, + ) - def handle_update_announcements(self, ns): + async def handle_update_announcements(self, ns): server = ns.server[0] with self.db.transaction(): if server.startswith("/"): - self._handle_update_announcements_from_bird(server) + await self._handle_update_announcements_from_bird(server) # Purge anything we never want here self.db.execute(""" @@ -1142,7 +1649,7 @@ class CLI(object): DELETE FROM announcements WHERE last_seen_at <= CURRENT_TIMESTAMP - INTERVAL '14 days'; """) - def _handle_update_announcements_from_bird(self, server): + async def _handle_update_announcements_from_bird(self, server): # Pre-compile the regular expression for faster searching route = re.compile(b"^\s(.+?)\s+.+?\[(?:AS(.*?))?.\]$") @@ -1171,11 +1678,25 @@ class CLI(object): # Fetch the extracted network and ASN network, autnum = m.groups() + # Skip the line if there is no network + if not network: + continue + # Decode into strings - if network: - network = network.decode() - if autnum: - autnum = autnum.decode() + network = network.decode() + + # Parse as network object + network = ipaddress.ip_network(network) + + # Skip announcements that are too large + if isinstance(network, ipaddress.IPv6Network): + if network.prefixlen < 10: + log.warning("Skipping unusually large network %s" % network) + continue + elif isinstance(network, ipaddress.IPv4Network): + if network.prefixlen < 4: + log.warning("Skipping unusually large network %s" % network) + continue # Collect all aggregated networks if not autnum: @@ -1183,11 +1704,14 @@ class CLI(object): aggregated_networks.append(network) continue + # Decode ASN + autnum = autnum.decode() + # Insert it into the database self.db.execute("INSERT INTO 
announcements(network, autnum) \ VALUES(%s, %s) ON CONFLICT (network) DO \ UPDATE SET autnum = excluded.autnum, last_seen_at = CURRENT_TIMESTAMP", - network, autnum, + "%s" % network, autnum, ) # Process any aggregated networks @@ -1260,7 +1784,7 @@ class CLI(object): # Otherwise return the line yield line - def handle_update_geofeeds(self, ns): + async def handle_update_geofeeds(self, ns): # Sync geofeeds with self.db.transaction(): # Delete all geofeeds which are no longer linked @@ -1268,26 +1792,32 @@ class CLI(object): DELETE FROM geofeeds WHERE - NOT EXISTS ( + geofeeds.url NOT IN ( SELECT - 1 + network_geofeeds.url FROM network_geofeeds - WHERE - geofeeds.url = network_geofeeds.url - )""", + ) + """, ) # Copy all geofeeds self.db.execute(""" + WITH all_geofeeds AS ( + SELECT + network_geofeeds.url + FROM + network_geofeeds + ) INSERT INTO - geofeeds( - url - ) + geofeeds + ( + url + ) SELECT url FROM - network_geofeeds + all_geofeeds ON CONFLICT (url) DO NOTHING """, @@ -1308,12 +1838,14 @@ class CLI(object): id """) - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - results = executor.map(self._fetch_geofeed, geofeeds) + ratelimiter = asyncio.Semaphore(32) - # Fetch all results to raise any exceptions - for result in results: - pass + # Update all geofeeds + async with asyncio.TaskGroup() as tasks: + for geofeed in geofeeds: + task = tasks.create_task( + self._fetch_geofeed(ratelimiter, geofeed), + ) # Delete data from any feeds that did not update in the last two weeks with self.db.transaction(): @@ -1333,23 +1865,32 @@ class CLI(object): ) """) - def _fetch_geofeed(self, geofeed): - log.debug("Fetching Geofeed %s" % geofeed.url) + async def _fetch_geofeed(self, ratelimiter, geofeed): + async with ratelimiter: + log.debug("Fetching Geofeed %s" % geofeed.url) - with self.db.transaction(): - # Open the URL - try: - req = urllib.request.Request(geofeed.url, headers={ - "User-Agent" : "location/%s" % location.__version__, + with self.db.transaction(): + # Open the URL + try: + # Send the request + f = await asyncio.to_thread( + self.downloader.retrieve, - # We expect some plain text file in CSV format - "Accept" : "text/csv, text/plain", - }) + # Fetch the feed by its URL + geofeed.url, - # XXX set proxy + # Send some extra headers + headers={ + "User-Agent" : "location/%s" % location.__version__, + + # We expect some plain text file in CSV format + "Accept" : "text/csv, text/plain", + }, + + # Don't wait longer than 10 seconds for a response + timeout=10, + ) - # Send the request - with urllib.request.urlopen(req, timeout=10) as f: # Remove any previous data self.db.execute("DELETE FROM geofeed_networks \ WHERE geofeed_id = %s", geofeed.id) @@ -1357,134 +1898,132 @@ class CLI(object): lineno = 0 # Read the output line by line - for line in f: - lineno += 1 + with self.db.pipeline(): + for line in f: + lineno += 1 - try: - line = line.decode() + try: + line = line.decode() - # Ignore any lines we cannot decode - except UnicodeDecodeError: - log.debug("Could not decode line %s in %s" \ - % (lineno, geofeed.url)) - continue + # Ignore any lines we cannot decode + except UnicodeDecodeError: + log.debug("Could not decode line %s in %s" \ + % (lineno, geofeed.url)) + continue - # Strip any newline - line = line.rstrip() + # Strip any newline + line = line.rstrip() - # Skip empty lines - if not line: - continue + # Skip empty lines + if not line: + continue - # Try to parse the line - try: - fields = line.split(",", 5) - except ValueError: - log.debug("Could not 
parse line: %s" % line) - continue + # Skip comments + elif line.startswith("#"): + continue - # Check if we have enough fields - if len(fields) < 4: - log.debug("Not enough fields in line: %s" % line) - continue + # Try to parse the line + try: + fields = line.split(",", 5) + except ValueError: + log.debug("Could not parse line: %s" % line) + continue - # Fetch all fields - network, country, region, city, = fields[:4] + # Check if we have enough fields + if len(fields) < 4: + log.debug("Not enough fields in line: %s" % line) + continue - # Try to parse the network - try: - network = ipaddress.ip_network(network, strict=False) - except ValueError: - log.debug("Could not parse network: %s" % network) - continue - - # Strip any excess whitespace from country codes - country = country.strip() - - # Make the country code uppercase - country = country.upper() - - # Check the country code - if not country: - log.debug("Empty country code in Geofeed %s line %s" \ - % (geofeed.url, lineno)) - continue - - elif not location.country_code_is_valid(country): - log.debug("Invalid country code in Geofeed %s:%s: %s" \ - % (geofeed.url, lineno, country)) - continue - - # Write this into the database - self.db.execute(""" - INSERT INTO - geofeed_networks ( - geofeed_id, - network, - country, - region, - city - ) - VALUES (%s, %s, %s, %s, %s)""", - geofeed.id, - "%s" % network, - country, - region, - city, - ) + # Fetch all fields + network, country, region, city, = fields[:4] - # Catch any HTTP errors - except urllib.request.HTTPError as e: - self.db.execute("UPDATE geofeeds SET status = %s, error = %s \ - WHERE id = %s", e.code, "%s" % e, geofeed.id) + # Try to parse the network + try: + network = ipaddress.ip_network(network, strict=False) + except ValueError: + log.debug("Could not parse network: %s" % network) + continue - # Remove any previous data when the feed has been deleted - if e.code == 404: - self.db.execute("DELETE FROM geofeed_networks \ - WHERE geofeed_id = %s", geofeed.id) + # Strip any excess whitespace from country codes + country = country.strip() - # Catch any other errors and connection timeouts - except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e: - log.debug("Could not fetch URL %s: %s" % (geofeed.url, e)) + # Make the country code uppercase + country = country.upper() - self.db.execute("UPDATE geofeeds SET status = %s, error = %s \ - WHERE id = %s", 599, "%s" % e, geofeed.id) + # Check the country code + if not country: + log.debug("Empty country code in Geofeed %s line %s" \ + % (geofeed.url, lineno)) + continue - # Mark the geofeed as updated - else: - self.db.execute(""" - UPDATE - geofeeds - SET - updated_at = CURRENT_TIMESTAMP, - status = NULL, - error = NULL - WHERE - id = %s""", - geofeed.id, - ) + elif not location.country_code_is_valid(country): + log.debug("Invalid country code in Geofeed %s:%s: %s" \ + % (geofeed.url, lineno, country)) + continue - def handle_update_overrides(self, ns): - with self.db.transaction(): - # Only drop manually created overrides, as we can be reasonably sure to have them, - # and preserve the rest. If appropriate, it is deleted by correspondent functions. 
- self.db.execute(""" - DELETE FROM autnum_overrides WHERE source = 'manual'; - DELETE FROM network_overrides WHERE source = 'manual'; - """) + # Write this into the database + self.db.execute(""" + INSERT INTO + geofeed_networks ( + geofeed_id, + network, + country, + region, + city + ) + VALUES (%s, %s, %s, %s, %s)""", + geofeed.id, + "%s" % network, + country, + region, + city, + ) + + # Catch any HTTP errors + except urllib.request.HTTPError as e: + self.db.execute("UPDATE geofeeds SET status = %s, error = %s \ + WHERE id = %s", e.code, "%s" % e, geofeed.id) + + # Remove any previous data when the feed has been deleted + if e.code == 404: + self.db.execute("DELETE FROM geofeed_networks \ + WHERE geofeed_id = %s", geofeed.id) + + # Catch any other errors and connection timeouts + except (http.client.InvalidURL, http.client.RemoteDisconnected, urllib.request.URLError, TimeoutError) as e: + log.debug("Could not fetch URL %s: %s" % (geofeed.url, e)) + + self.db.execute("UPDATE geofeeds SET status = %s, error = %s \ + WHERE id = %s", 599, "%s" % e, geofeed.id) + + # Mark the geofeed as updated + else: + self.db.execute(""" + UPDATE + geofeeds + SET + updated_at = CURRENT_TIMESTAMP, + status = NULL, + error = NULL + WHERE + id = %s""", + geofeed.id, + ) - # Update overrides for various cloud providers big enough to publish their own IP - # network allocation lists in a machine-readable format... - self._update_overrides_for_aws() + async def handle_update_overrides(self, ns): + with self.db.transaction(): + # Drop any previous content + self.db.execute("TRUNCATE TABLE autnum_overrides") + self.db.execute("TRUNCATE TABLE network_overrides") - # Update overrides for Spamhaus DROP feeds... - self._update_overrides_for_spamhaus_drop() + # Remove all Geofeeds + self.db.execute("DELETE FROM network_geofeeds WHERE source = %s", "overrides") for file in ns.files: log.info("Reading %s..." 
% file) with open(file, "rb") as f: - for type, block in location.importer.read_blocks(f): + for type, block in read_blocks(f): if type == "net": network = block.get("net") # Try to parse and normalise the network @@ -1500,19 +2039,24 @@ class CLI(object): continue self.db.execute(""" - INSERT INTO network_overrides( + INSERT INTO + network_overrides + ( network, country, - source, is_anonymous_proxy, is_satellite_provider, is_anycast, is_drop - ) VALUES (%s, %s, %s, %s, %s, %s, %s) - ON CONFLICT (network) DO NOTHING""", + ) + VALUES + ( + %s, %s, %s, %s, %s, %s + ) + ON CONFLICT (network) DO NOTHING + """, "%s" % network, block.get("country"), - "manual", self._parse_bool(block, "is-anonymous-proxy"), self._parse_bool(block, "is-satellite-provider"), self._parse_bool(block, "is-anycast"), @@ -1531,269 +2075,390 @@ class CLI(object): autnum = autnum[2:] self.db.execute(""" - INSERT INTO autnum_overrides( + INSERT INTO + autnum_overrides + ( number, name, country, - source, is_anonymous_proxy, is_satellite_provider, is_anycast, is_drop - ) VALUES(%s, %s, %s, %s, %s, %s, %s, %s) - ON CONFLICT DO NOTHING""", + ) + VALUES + ( + %s, %s, %s, %s, %s, %s, %s + ) + ON CONFLICT (number) DO NOTHING + """, autnum, block.get("name"), block.get("country"), - "manual", self._parse_bool(block, "is-anonymous-proxy"), self._parse_bool(block, "is-satellite-provider"), self._parse_bool(block, "is-anycast"), self._parse_bool(block, "drop"), ) + # Geofeeds + elif type == "geofeed": + networks = [] + + # Fetch the URL + url = block.get("geofeed") + + # Fetch permitted networks + for n in block.get("network", []): + try: + n = ipaddress.ip_network(n) + except ValueError as e: + log.warning("Ignoring invalid network %s: %s" % (n, e)) + continue + + networks.append(n) + + # If no networks have been specified, permit for everything + if not networks: + networks = [ + ipaddress.ip_network("::/0"), + ipaddress.ip_network("0.0.0.0/0"), + ] + + # Check the URL + url = self._check_geofeed_url(url) + if not url: + continue + + # Store the Geofeed URL + self.db.execute(""" + INSERT INTO + geofeeds + ( + url + ) + VALUES + ( + %s + ) + ON CONFLICT (url) DO NOTHING + """, url, + ) + + # Store all permitted networks + self.db.executemany(""" + INSERT INTO + network_geofeeds + ( + network, + url, + source + ) + VALUES + ( + %s, %s, %s + ) + ON CONFLICT + ( + network, url + ) + DO UPDATE SET + source = excluded.source + """, (("%s" % n, url, "overrides") for n in networks), + ) + else: log.warning("Unsupported type: %s" % type) - def _update_overrides_for_aws(self): - # Download Amazon AWS IP allocation file to create overrides... 
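The override files read above use a plain block format: "key: value" lines grouped into paragraphs, where the first key of a block names its type (for example "net" or "geofeed") and blank lines separate blocks. This is exactly what read_blocks(), defined near the bottom of the script, yields. A small self-contained demonstration with an invented override file (keys match what the handler reads; the concrete values are made up):

    import io

    example = b"\n".join([
        b"net: 192.0.2.0/24",
        b"country: DE",
        b"is-anycast: yes",
        b"",
        b"geofeed: https://example.com/geofeed.csv",
        b"network: 198.51.100.0/24",
        b"network: 2001:db8::/32",
    ])

    for type, block in read_blocks(io.BytesIO(example)):
        print(type, block)
    # net     {'net': '192.0.2.0/24', 'country': 'DE', 'is-anycast': 'yes'}
    # geofeed {'geofeed': 'https://example.com/geofeed.csv',
    #          'network': ['198.51.100.0/24', '2001:db8::/32']}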
- downloader = location.importer.Downloader() + async def handle_update_feeds(self, ns): + """ + Update any third-party feeds + """ + success = True + + feeds = ( + # AWS IP Ranges + ("AWS-IP-RANGES", self._import_aws_ip_ranges, "https://ip-ranges.amazonaws.com/ip-ranges.json"), - try: - # Fetch IP ranges - f = downloader.retrieve("https://ip-ranges.amazonaws.com/ip-ranges.json") + # Spamhaus DROP + ("SPAMHAUS-DROP", self._import_spamhaus_drop, "https://www.spamhaus.org/drop/drop.txt"), + ("SPAMHAUS-DROPV6", self._import_spamhaus_drop, "https://www.spamhaus.org/drop/dropv6.txt"), - # Parse downloaded file - aws_ip_dump = json.load(f) - except Exception as e: - log.error("unable to preprocess Amazon AWS IP ranges: %s" % e) - return + # Spamhaus ASNDROP + ("SPAMHAUS-ASNDROP", self._import_spamhaus_asndrop, "https://www.spamhaus.org/drop/asndrop.json"), + ) - # At this point, we can assume the downloaded file to be valid - self.db.execute(""" - DELETE FROM network_overrides WHERE source = 'Amazon AWS IP feed'; - """) + # Drop any data from feeds that we don't support (any more) + with self.db.transaction(): + # Fetch the names of all feeds we support + sources = [name for name, *rest in feeds] + + self.db.execute("DELETE FROM autnum_feeds WHERE NOT source = ANY(%s)", sources) + self.db.execute("DELETE FROM network_feeds WHERE NOT source = ANY(%s)", sources) - # XXX: Set up a dictionary for mapping a region name to a country. Unfortunately, + # Walk through all feeds + for name, callback, url, *args in feeds: + # Skip any feeds that were not requested on the command line + if ns.feeds and not name in ns.feeds: + continue + + try: + await self._process_feed(name, callback, url, *args) + + # Log an error but continue if an exception occurs + except Exception as e: + log.error("Error processing feed '%s': %s" % (name, e)) + success = False + + # Return status + return 0 if success else 1 + + async def _process_feed(self, name, callback, url, *args): + """ + Processes one feed + """ + # Open the URL + f = self.downloader.retrieve(url) + + with self.db.transaction(): + # Drop any previous content + self.db.execute("DELETE FROM autnum_feeds WHERE source = %s", name) + self.db.execute("DELETE FROM network_feeds WHERE source = %s", name) + + # Call the callback to process the feed + with self.db.pipeline(): + return await callback(name, f, *args) + + async def _import_aws_ip_ranges(self, name, f): + # Parse the feed + feed = json.load(f) + + # Set up a dictionary for mapping a region name to a country. Unfortunately, # there seems to be no machine-readable version available of this other than # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/using-regions-availability-zones.html # (worse, it seems to be incomplete :-/ ); https://www.cloudping.cloud/endpoints # was helpful here as well. 
aws_region_country_map = { - "af-south-1": "ZA", - "ap-east-1": "HK", - "ap-south-1": "IN", - "ap-south-2": "IN", - "ap-northeast-3": "JP", - "ap-northeast-2": "KR", - "ap-southeast-1": "SG", - "ap-southeast-2": "AU", - "ap-southeast-3": "MY", - "ap-southeast-4": "AU", - "ap-southeast-5": "NZ", # Auckland, NZ - "ap-southeast-6": "AP", # XXX: Precise location not documented anywhere - "ap-northeast-1": "JP", - "ca-central-1": "CA", - "ca-west-1": "CA", - "eu-central-1": "DE", - "eu-central-2": "CH", - "eu-west-1": "IE", - "eu-west-2": "GB", - "eu-south-1": "IT", - "eu-south-2": "ES", - "eu-west-3": "FR", - "eu-north-1": "SE", - "il-central-1": "IL", # XXX: This one is not documented anywhere except for ip-ranges.json itself - "me-central-1": "AE", - "me-south-1": "BH", - "sa-east-1": "BR" - } + # Africa + "af-south-1" : "ZA", + + # Asia + "il-central-1" : "IL", # Tel Aviv + + # Asia/Pacific + "ap-northeast-1" : "JP", + "ap-northeast-2" : "KR", + "ap-northeast-3" : "JP", + "ap-east-1" : "HK", + "ap-south-1" : "IN", + "ap-south-2" : "IN", + "ap-southeast-1" : "SG", + "ap-southeast-2" : "AU", + "ap-southeast-3" : "MY", + "ap-southeast-4" : "AU", # Melbourne + "ap-southeast-5" : "MY", # Malaysia + "ap-southeast-6" : "AP", # XXX: Precise location not documented anywhere + "ap-southeast-7" : "AP", # XXX: Precise location unknown + + # Canada + "ca-central-1" : "CA", + "ca-west-1" : "CA", + + # Europe + "eu-central-1" : "DE", + "eu-central-2" : "CH", + "eu-north-1" : "SE", + "eu-west-1" : "IE", + "eu-west-2" : "GB", + "eu-west-3" : "FR", + "eu-south-1" : "IT", + "eu-south-2" : "ES", + + # Middle East + "me-central-1" : "AE", + "me-south-1" : "BH", + + # Mexico + "mx-central-1" : "MX", + + # South America + "sa-east-1" : "BR", + + # Undocumented, likely located in Berlin rather than Frankfurt + "eusc-de-east-1" : "DE", + } + + # Collect a list of all networks + prefixes = feed.get("ipv6_prefixes", []) + feed.get("prefixes", []) + + for prefix in prefixes: + # Fetch network + network = prefix.get("ipv6_prefix") or prefix.get("ip_prefix") + + # Parse the network + try: + network = ipaddress.ip_network(network) + except ValueError as e: + log.warning("%s: Unable to parse prefix %s" % (name, network)) + continue - # Fetch all valid country codes to check parsed networks aganist... - rows = self.db.query("SELECT * FROM countries ORDER BY country_code") - validcountries = [] + # Sanitize parsed networks... + if not self._check_parsed_network(network): + continue - for row in rows: - validcountries.append(row.country_code) + # Fetch the region + region = prefix.get("region") - with self.db.transaction(): - for snetwork in aws_ip_dump["prefixes"] + aws_ip_dump["ipv6_prefixes"]: - try: - network = ipaddress.ip_network(snetwork.get("ip_prefix") or snetwork.get("ipv6_prefix"), strict=False) - except ValueError: - log.warning("Unable to parse line: %s" % snetwork) - continue + # Set some defaults + cc = None + is_anycast = False - # Sanitize parsed networks... - if not self._check_parsed_network(network): - continue + # Fetch the CC from the dictionary + try: + cc = aws_region_country_map[region] + # Determine region of this network... - region = snetwork["region"] - cc = None - is_anycast = False + # If we couldn't find anything, let's try something else... + except KeyError as e: + # Find anycast networks + if region == "GLOBAL": + is_anycast = True - # Any region name starting with "us-" will get "US" country code assigned straight away... 
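For context, ip-ranges.json is a single JSON document with two arrays, "prefixes" and "ipv6_prefixes"; the importer concatenates both and only consumes the prefix itself and its "region". A trimmed, illustrative sample (entries invented, structure as the code above expects; the real file carries additional fields such as the service name):

    import ipaddress, json

    sample = json.loads("""
    {"prefixes": [{"ip_prefix": "198.51.100.0/24", "region": "eu-central-1"}],
     "ipv6_prefixes": [{"ipv6_prefix": "2001:db8::/32", "region": "GLOBAL"}]}
    """)

    # Same merge order as _import_aws_ip_ranges() above
    for prefix in sample.get("ipv6_prefixes", []) + sample.get("prefixes", []):
        network = prefix.get("ipv6_prefix") or prefix.get("ip_prefix")
        print(ipaddress.ip_network(network), prefix.get("region"))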
- if region.startswith("us-"): + # Everything that starts with us- is probably in the United States + elif region.startswith("us-"): cc = "US" + + # Everything that starts with cn- is probably China elif region.startswith("cn-"): - # ... same goes for China ... cc = "CN" - elif region == "GLOBAL": - # ... funny region name for anycast-like networks ... - is_anycast = True - elif region in aws_region_country_map: - # ... assign looked up country code otherwise ... - cc = aws_region_country_map[region] + + # Log a warning for anything else else: - # ... and bail out if we are missing something here - log.warning("Unable to determine country code for line: %s" % snetwork) + log.warning("%s: Could not determine country code for AWS region %s" \ + % (name, region)) continue - # Skip networks with unknown country codes - if not is_anycast and validcountries and cc not in validcountries: - log.warning("Skipping Amazon AWS network with bogus country '%s': %s" % \ - (cc, network)) - return - - # Conduct SQL statement... - self.db.execute(""" - INSERT INTO network_overrides( - network, - country, - source, - is_anonymous_proxy, - is_satellite_provider, - is_anycast - ) VALUES (%s, %s, %s, %s, %s, %s) - ON CONFLICT (network) DO NOTHING""", - "%s" % network, - cc, - "Amazon AWS IP feed", - None, - None, - is_anycast, + # Write to database + self.db.execute(""" + INSERT INTO + network_feeds + ( + network, + source, + country, + is_anycast ) + VALUES + ( + %s, %s, %s, %s + ) + ON CONFLICT (network, source) DO NOTHING + """, "%s" % network, name, cc, is_anycast, + ) + async def _import_spamhaus_drop(self, name, f): + """ + Import Spamhaus DROP IP feeds + """ + # Count all lines + lines = 0 - def _update_overrides_for_spamhaus_drop(self): - downloader = location.importer.Downloader() - - ip_lists = [ - ("SPAMHAUS-DROP", "https://www.spamhaus.org/drop/drop.txt"), - ("SPAMHAUS-EDROP", "https://www.spamhaus.org/drop/edrop.txt"), - ("SPAMHAUS-DROPV6", "https://www.spamhaus.org/drop/dropv6.txt") - ] - - asn_lists = [ - ("SPAMHAUS-ASNDROP", "https://www.spamhaus.org/drop/asndrop.txt") - ] - - for name, url in ip_lists: - # Fetch IP list from given URL - f = downloader.retrieve(url) - - # Split into lines - fcontent = f.readlines() + # Walk through all lines + for line in f: + # Decode line + line = line.decode("utf-8") - with self.db.transaction(): - # Conduct a very basic sanity check to rule out CDN issues causing bogus DROP - # downloads. - if len(fcontent) > 10: - self.db.execute("DELETE FROM network_overrides WHERE source = %s", name) - else: - log.warning("%s (%s) returned likely bogus file, ignored" % (name, url)) - continue + # Strip off any comments + line, _, comment = line.partition(";") - # Iterate through every line, filter comments and add remaining networks to - # the override table in case they are valid... - for sline in fcontent: - # The response is assumed to be encoded in UTF-8... - sline = sline.decode("utf-8") + # Ignore empty lines + if not line: + continue - # Comments start with a semicolon... - if sline.startswith(";"): - continue + # Strip any excess whitespace + line = line.strip() - # Extract network and ignore anything afterwards... - try: - network = ipaddress.ip_network(sline.split()[0], strict=False) - except ValueError: - log.error("Unable to parse line: %s" % sline) - continue + # Increment line counter + lines += 1 - # Sanitize parsed networks... 
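_import_spamhaus_drop() assumes the long-standing DROP text format: one network per line, with everything from a semicolon onwards being a comment (the removed code relied on the same convention). A minimal sketch of the per-line handling, with invented sample lines:

    import ipaddress

    sample = [
        b"; Spamhaus DROP List",        # header comment, skipped
        b"192.0.2.0/24 ; SBL000000",    # network with a trailing comment
        b"not-a-network ; SBL000001",   # parse failure, skipped with a warning
    ]

    for line in sample:
        line, _, comment = line.decode("utf-8").partition(";")
        line = line.strip()
        if not line:
            continue
        try:
            network = ipaddress.ip_network(line)
        except ValueError:
            continue
        print(network)  # only 192.0.2.0/24 survives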
- if not self._check_parsed_network(network): - log.warning("Skipping bogus network found in %s (%s): %s" % \ - (name, url, network)) - continue + # Parse the network + try: + network = ipaddress.ip_network(line) + except ValueError as e: + log.warning("%s: Could not parse network: %s - %s" % (name, line, e)) + continue - # Conduct SQL statement... - self.db.execute(""" - INSERT INTO network_overrides( - network, - source, - is_drop - ) VALUES (%s, %s, %s) - ON CONFLICT (network) DO UPDATE SET is_drop = True""", - "%s" % network, - name, - True - ) + # Check network + if not self._check_parsed_network(network): + log.warning("%s: Skipping bogus network: %s" % (name, network)) + continue - for name, url in asn_lists: - # Fetch URL - f = downloader.retrieve(url) + # Insert into the database + self.db.execute(""" + INSERT INTO + network_feeds + ( + network, + source, + is_drop + ) + VALUES + ( + %s, %s, %s + )""", "%s" % network, name, True, + ) - # Split into lines - fcontent = f.readlines() + # Raise an exception if we could not import anything + if not lines: + raise RuntimeError("Received bogus feed %s with no data" % name) - with self.db.transaction(): - # Conduct a very basic sanity check to rule out CDN issues causing bogus DROP - # downloads. - if len(fcontent) > 10: - self.db.execute("DELETE FROM autnum_overrides WHERE source = %s", name) - else: - log.warning("%s (%s) returned likely bogus file, ignored" % (name, url)) - continue + async def _import_spamhaus_asndrop(self, name, f): + """ + Import Spamhaus ASNDROP feed + """ + for line in f: + # Decode the line + line = line.decode("utf-8") - # Iterate through every line, filter comments and add remaining ASNs to - # the override table in case they are valid... - for sline in f.readlines(): - # The response is assumed to be encoded in UTF-8... - sline = sline.decode("utf-8") + # Parse JSON + try: + line = json.loads(line) + except json.JSONDecodeError as e: + log.warning("%s: Unable to parse JSON object %s: %s" % (name, line, e)) + continue - # Comments start with a semicolon... - if sline.startswith(";"): - continue + # Fetch type + type = line.get("type") - # Throw away anything after the first space... - sline = sline.split()[0] + # Skip any metadata + if type == "metadata": + continue - # ... strip the "AS" prefix from it ... - sline = sline.strip("AS") + # Fetch ASN + asn = line.get("asn") - # ... and convert it into an integer. Voila. - asn = int(sline) + # Skip any lines without an ASN + if not asn: + continue - # Filter invalid ASNs... - if not self._check_parsed_asn(asn): - log.warning("Skipping bogus ASN found in %s (%s): %s" % \ - (name, url, asn)) - continue + # Filter invalid ASNs + if not self._check_parsed_asn(asn): + log.warning("%s: Skipping bogus ASN %s" % (name, asn)) + continue - # Conduct SQL statement... 
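The ASNDROP feed parsed above is line-delimited JSON: one object per listed ASN plus a metadata record that the code skips explicitly. An illustrative pair of records (field values invented; the importer only looks at "type" and "asn"):

    import json

    sample = [
        b'{"asn": 64496, "asname": "EXAMPLE-AS"}',
        b'{"type": "metadata", "timestamp": 1700000000}',
    ]

    for line in sample:
        line = json.loads(line.decode("utf-8"))

        # Skip the metadata record
        if line.get("type") == "metadata":
            continue

        asn = line.get("asn")
        if asn:
            print(asn)  # 64496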
- self.db.execute(""" - INSERT INTO autnum_overrides( - number, - source, - is_drop - ) VALUES (%s, %s, %s) - ON CONFLICT (number) DO UPDATE SET is_drop = True""", - "%s" % asn, - name, - True - ) + # Write to database + self.db.execute(""" + INSERT INTO + autnum_feeds + ( + number, + source, + is_drop + ) + VALUES + ( + %s, %s, %s + )""", "%s" % asn, name, True, + ) @staticmethod def _parse_bool(block, key): @@ -1817,15 +2482,7 @@ class CLI(object): # Default to None return None - @property - def countries(self): - # Fetch all valid country codes to check parsed networks aganist - rows = self.db.query("SELECT * FROM countries ORDER BY country_code") - - # Return all countries - return [row.country_code for row in rows] - - def handle_import_countries(self, ns): + async def handle_import_countries(self, ns): with self.db.transaction(): # Drop all data that we have self.db.execute("TRUNCATE TABLE countries") @@ -1857,9 +2514,89 @@ def split_line(line): return key, val -def main(): +def read_blocks(f): + for block in iterate_over_blocks(f): + type = None + data = {} + + for i, line in enumerate(block): + key, value = line.split(":", 1) + + # The key of the first line defines the type + if i == 0: + type = key + + # Strip any excess whitespace + value = value.strip() + + # Store some values as a list + if type == "geofeed" and key == "network": + try: + data[key].append(value) + except KeyError: + data[key] = [value] + + # Otherwise store the value as string + else: + data[key] = value + + yield type, data + +def iterate_over_blocks(f, charsets=("utf-8", "latin1")): + block = [] + + for line in f: + # Skip commented lines + if line.startswith(b"#") or line.startswith(b"%"): + continue + + # Convert to string + for charset in charsets: + try: + line = line.decode(charset) + except UnicodeDecodeError: + continue + else: + break + + # Remove any comments at the end of line + line, hash, comment = line.partition("#") + + # Strip any whitespace at the end of the line + line = line.rstrip() + + # If we cut off some comment and the line is empty, we can skip it + if comment and not line: + continue + + # If the line has some content, keep collecting it + if line: + block.append(line) + continue + + # End the block on an empty line + if block: + yield block + + # Reset the block + block = [] + + # Return the last block + if block: + yield block + +def iterate_over_lines(f): + for line in f: + # Decode the line + line = line.decode() + + # Strip the ending + yield line.rstrip() + +async def main(): # Run the command line interface c = CLI() - c.run() -main() + await c.run() + +asyncio.run(main()) diff --git a/src/stringpool.c b/src/stringpool.c index 9986a61..f2f55e0 100644 --- a/src/stringpool.c +++ b/src/stringpool.c @@ -127,7 +127,7 @@ int loc_stringpool_open(struct loc_ctx* ctx, struct loc_stringpool** pool, p->data = data; p->length = length; - DEBUG(p->ctx, "Opened string pool at %p (%zu bytes)\n", p->data, p->length); + DEBUG(p->ctx, "Opened string pool at %p (%zd bytes)\n", p->data, p->length); *pool = p; return 0; diff --git a/src/test-address.c b/src/test-address.c index 7012e41..a13af77 100644 --- a/src/test-address.c +++ b/src/test-address.c @@ -54,7 +54,7 @@ static int perform_tests(struct loc_ctx* ctx, const int family) { for (unsigned int i = 0; i < 100; i++) { s = loc_address_str(&address); - printf("Iteration %d: %s\n", i, s); + printf("Iteration %u: %s\n", i, s); if (strcmp(s, e) != 0) { fprintf(stderr, "IP address was formatted in an invalid format: %s\n", s); diff --git 
a/src/test-as.c b/src/test-as.c index b135c6b..07861d4 100644 --- a/src/test-as.c +++ b/src/test-as.c @@ -85,7 +85,7 @@ int main(int argc, char** argv) { for (unsigned int i = 1; i <= 10; i++) { err = loc_database_get_as(db, &as, i); if (err) { - fprintf(stderr, "Could not find AS%d\n", i); + fprintf(stderr, "Could not find AS%u\n", i); exit(EXIT_FAILURE); } @@ -110,7 +110,7 @@ int main(int argc, char** argv) { } while (as) { - printf("Found AS%d: %s\n", loc_as_get_number(as), loc_as_get_name(as)); + printf("Found AS%u: %s\n", loc_as_get_number(as), loc_as_get_name(as)); err = loc_database_enumerator_next_as(enumerator, &as); if (err) { diff --git a/src/test-database.c b/src/test-database.c index 8ba558a..bfa218f 100644 --- a/src/test-database.c +++ b/src/test-database.c @@ -46,7 +46,7 @@ const char* networks[] = { NULL, }; -static int attempt_to_open(struct loc_ctx* ctx, char* path) { +static int attempt_to_open(struct loc_ctx* ctx, const char* path) { FILE* f = fopen(path, "r"); if (!f) return -1; diff --git a/src/test-network.c b/src/test-network.c index 717ad3a..0cac1a4 100644 --- a/src/test-network.c +++ b/src/test-network.c @@ -29,6 +29,67 @@ #include #include +static int test_reverse_pointers(struct loc_ctx* ctx) { + const struct test { + const char* network; + const char* rp; + } tests[] = { + // IPv6 + { "::1/128", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa." }, + { "2001:db8::/32", "*.8.b.d.0.1.0.0.2.ip6.arpa." }, + + // IPv4 + { "10.0.0.0/32", "0.0.0.10.in-addr.arpa." }, + { "10.0.0.0/24", "*.0.0.10.in-addr.arpa." }, + { "10.0.0.0/16", "*.0.10.in-addr.arpa." }, + { "10.0.0.0/8", "*.10.in-addr.arpa." }, + { "10.0.0.0/0", "*.in-addr.arpa." }, + { "10.0.0.0/1", NULL, }, + { NULL, NULL }, + }; + + struct loc_network* network = NULL; + char* rp = NULL; + int r; + + for (const struct test* test = tests; test->network; test++) { + // Create a new network + r = loc_network_new_from_string(ctx, &network, test->network); + if (r) + return r; + + // Fetch the reverse pointer + rp = loc_network_reverse_pointer(network, NULL); + + // No RP expected and got none + if (!test->rp && !rp) + continue; + + // Got a result when expecting none + else if (!test->rp && rp) { + fprintf(stderr, "Got an RP for %s when expecting none\n", test->network); + return EXIT_FAILURE; + + // Got nothing when expecting a result + } else if (test->rp && !rp) { + fprintf(stderr, "Got no RP for %s when expecting one\n", test->network); + return EXIT_FAILURE; + + // Compare values + } else if (strcmp(test->rp, rp) != 0) { + fprintf(stderr, "Got an unexpected RP for %s: Got %s, expected %s\n", + test->network, rp, test->rp); + return EXIT_FAILURE; + } + + loc_network_unref(network); + if (rp) + free(rp); + } + + return 0; +} + int main(int argc, char** argv) { int err; @@ -298,7 +359,7 @@ int main(int argc, char** argv) { // Lookup an address outside the subnet err = loc_database_lookup_from_string(db, "2001:db8:fffe:1::", &network1); - if (err == 0) { + if (err || network1) { fprintf(stderr, "Could look up 2001:db8:fffe:1::, but I shouldn't\n"); exit(EXIT_FAILURE); } @@ -308,9 +369,9 @@ int main(int argc, char** argv) { unsigned int bit_length; } bit_length_tests[] = { { "::/0", 0 }, - { "2001::/128", 126 }, - { "1.0.0.0/32", 25 }, - { "0.0.0.1/32", 1 }, + { "2001::/128", 16 }, + { "1.0.0.0/32", 8 }, + { "0.0.0.1/32", 32 }, { "255.255.255.255/32", 32 }, { NULL, 0, }, }; @@ -327,7 +388,7 @@ int main(int argc, char** argv) { unsigned int bit_length = loc_address_bit_length(addr); if 
(bit_length != t->bit_length) { - printf("Bit length of %s didn't match: %u != %u\n", + printf("Bit length of %s didn't match: expected %u, got %u\n", t->network, t->bit_length, bit_length); loc_network_unref(network1); exit(EXIT_FAILURE); @@ -336,6 +397,11 @@ int main(int argc, char** argv) { loc_network_unref(network1); } + // Test reverse pointers + err = test_reverse_pointers(ctx); + if (err) + exit(err); + loc_unref(ctx); fclose(f); diff --git a/src/test-stringpool.c b/src/test-stringpool.c index a94d8f8..5439e56 100644 --- a/src/test-stringpool.c +++ b/src/test-stringpool.c @@ -108,7 +108,7 @@ int main(int argc, char** argv) { free(string); if (pos < 0) { - fprintf(stderr, "Could not add string %d: %m\n", i); + fprintf(stderr, "Could not add string %u: %m\n", i); exit(EXIT_FAILURE); } } diff --git a/src/writer.c b/src/writer.c index a3cb993..9c1db3a 100644 --- a/src/writer.c +++ b/src/writer.c @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -478,7 +479,7 @@ static int loc_database_write_networks(struct loc_writer* writer, } // Write the current node - DEBUG(writer->ctx, "Writing node %p (0 = %d, 1 = %d)\n", + DEBUG(writer->ctx, "Writing node %p (0 = %u, 1 = %u)\n", node, node->index_zero, node->index_one); *offset += fwrite(&db_node, 1, sizeof(db_node), f); @@ -642,6 +643,8 @@ END: } LOC_EXPORT int loc_writer_write(struct loc_writer* writer, FILE* f, enum loc_database_version version) { + size_t bytes_written = 0; + // Check version switch (version) { case LOC_DATABASE_VERSION_UNSET: @@ -652,11 +655,11 @@ LOC_EXPORT int loc_writer_write(struct loc_writer* writer, FILE* f, enum loc_dat break; default: - ERROR(writer->ctx, "Invalid database version: %d\n", version); + ERROR(writer->ctx, "Invalid database version: %u\n", version); return -1; } - DEBUG(writer->ctx, "Writing database in version %d\n", version); + DEBUG(writer->ctx, "Writing database in version %u\n", version); struct loc_database_magic magic; make_magic(writer, &magic, version); @@ -765,7 +768,16 @@ LOC_EXPORT int loc_writer_write(struct loc_writer* writer, FILE* f, enum loc_dat if (r) return r; - fwrite(&header, 1, sizeof(header), f); + bytes_written = fwrite(&header, 1, sizeof(header), f); + if (bytes_written < sizeof(header)) { + ERROR(writer->ctx, "Could not write header: %s\n", strerror(errno)); + return r; + } + + // Seek back to the end + r = fseek(f, 0, SEEK_END); + if (r) + return r; // Flush everything fflush(f); diff --git a/tests/lua/main.lua.in b/tests/lua/main.lua.in new file mode 100755 index 0000000..f436a5e --- /dev/null +++ b/tests/lua/main.lua.in @@ -0,0 +1,216 @@ +#!/usr/bin/lua@LUA_VERSION@ +--[[########################################################################### +# # +# libloc - A library to determine the location of someone on the Internet # +# # +# Copyright (C) 2024 IPFire Development Team # +# # +# This library is free software; you can redistribute it and/or # +# modify it under the terms of the GNU Lesser General Public # +# License as published by the Free Software Foundation; either # +# version 2.1 of the License, or (at your option) any later version. # +# # +# This library is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# Lesser General Public License for more details. 
# +# # +############################################################################--]] + +luaunit = require("luaunit") + +ENV_TEST_DATABASE = os.getenv("TEST_DATABASE") +ENV_TEST_SIGNING_KEY = os.getenv("TEST_SIGNING_KEY") + +function test_load() + -- Try loading the module + location = require("location") + + -- Print the version + print(location.version()) +end + +log_callback_called = 0 + +function log_callback(level, message) + log_callback_called = true + print("LOG " .. message) +end + +function test_log_callback() + location = require("location") + + -- Set the callback + location.set_log_callback(log_callback) + + -- Enable debugging + location.set_log_level(7) + + -- Perform some random operation + local db = location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + luaunit.assertIsTrue(log_callback_called) +end + +function test_open_database() + location = require("location") + + -- Open the database + db = location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + -- Verify + luaunit.assertIsTrue(db:verify(ENV_TEST_SIGNING_KEY)) + + -- Description + luaunit.assertIsString(db:get_description()) + + -- License + luaunit.assertIsString(db:get_license()) + luaunit.assertEquals(db:get_license(), "CC BY-SA 4.0") + + -- Vendor + luaunit.assertIsString(db:get_vendor()) + luaunit.assertEquals(db:get_vendor(), "IPFire Project") +end + +function test_lookup() + location = require("location") + + -- Open the database + db = location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + -- Perform a lookup + network1 = db:lookup("81.3.27.32") + + luaunit.assertEquals(network1:get_family(), 2) -- AF_INET + luaunit.assertEquals(network1:get_country_code(), "DE") + luaunit.assertEquals(network1:get_asn(), 24679) + + -- Lookup something else + network2 = db:lookup("8.8.8.8") + luaunit.assertIsTrue(network2:has_flag(location.NETWORK_FLAG_ANYCAST)) + luaunit.assertIsFalse(network2:has_flag(location.NETWORK_FLAG_DROP)) +end + +function test_network() + location = require("location") + + n1 = location.Network.new("10.0.0.0/8") + luaunit.assertNotNil(n1) + + -- The ASN should be nul + luaunit.assertNil(n1:get_asn()) + + -- The family should be IPv4 + luaunit.assertEquals(n1:get_family(), 2) -- AF_INET + + -- The country code should be empty + luaunit.assertNil(n1:get_country_code()) +end + +function test_as() + location = require("location") + luaunit.assertNotNil(location) + + -- Create a new AS + as = location.AS.new(12345) + luaunit.assertEquals(as:get_number(), 12345) + luaunit.assertNil(as:get_name()) + + -- Reset + as = nil +end + +function test_fetch_as() + location = require("location") + + -- Open the database + db = location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + -- Fetch an AS + as = db:get_as(0) + + -- This should not exist + luaunit.assertNil(as) + + -- Fetch something that exists + as = db:get_as(204867) + luaunit.assertEquals(as:get_number(), 204867) + luaunit.assertEquals(as:get_name(), "Lightning Wire Labs GmbH") +end + +function test_country() + location = require("location") + + c1 = location.Country.new("DE") + luaunit.assertNotNil(c1) + luaunit.assertEquals(c1:get_code(), "DE") + luaunit.assertNil(c1:get_name()) + luaunit.assertNil(c1:get_continent_code()) + + c2 = location.Country.new("GB") + luaunit.assertNotNil(c2) + luaunit.assertNotEquals(c1, c2) + + c1 = nil + c2 = nil +end + +function test_fetch_country() + location = require("location") + + -- Open the database + db = 
location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + -- Fetch an invalid country + c = db:get_country("XX") + luaunit.assertNil(c) + + -- Fetch something that exists + c = db:get_country("DE") + luaunit.assertEquals(c:get_code(), "DE") + luaunit.assertEquals(c:get_name(), "Germany") +end + +-- This test is not very deterministic but should help to test the GC methods +function test_gc() + print("GC: " .. collectgarbage("collect")) +end + +function test_subnets() + location = require("location") + + -- Open the database + db = location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + local network = db:lookup("1.1.1.1") + + local subnets = network:subnets() + + luaunit.assertIsTable(subnets) + luaunit.assertEquals(#subnets, 2) + + for i, subnet in ipairs(subnets) do + print(subnet) + end +end + +function test_list_networks() + location = require("location") + + -- Open the database + db = location.Database.open(ENV_TEST_DATABASE) + luaunit.assertNotNil(db) + + for network in db:list_networks() do + print(network, network:reverse_pointer()) + end +end + +os.exit(luaunit.LuaUnit.run()) diff --git a/tests/python/country.py b/tests/python/country.py new file mode 100755 index 0000000..d38d46a --- /dev/null +++ b/tests/python/country.py @@ -0,0 +1,73 @@ +#!/usr/bin/python3 +############################################################################### +# # +# libloc - A library to determine the location of someone on the Internet # +# # +# Copyright (C) 2024 IPFire Development Team # +# # +# This library is free software; you can redistribute it and/or # +# modify it under the terms of the GNU Lesser General Public # +# License as published by the Free Software Foundation; either # +# version 2.1 of the License, or (at your option) any later version. # +# # +# This library is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# Lesser General Public License for more details. 
# +# # +############################################################################### + +import location +import unittest + +class Test(unittest.TestCase): + def test_properties(self): + c = location.Country("DE") + + # The code should be DE + self.assertEqual(c.code, "DE") + + # All other attributes should return None + self.assertIsNone(c.name) + self.assertIsNone(c.continent_code) + + # Set a name and read it back + c.name = "Germany" + self.assertEqual(c.name, "Germany") + + # Set a continent code and read it back + c.continent_code = "EU" + self.assertEqual(c.continent_code, "EU") + + def test_country_cmp(self): + """ + Performs some comparison tests + """ + c1 = location.Country("DE") + c2 = location.Country("DE") + + # c1 and c2 should be equal + self.assertEqual(c1, c2) + + # We cannot compare against strings for example + self.assertNotEqual(c1, "DE") + + c3 = location.Country("AT") + + # c1 and c3 should not be equal + self.assertNotEqual(c1, c3) + + # c3 comes before c1 (alphabetically) + self.assertGreater(c1, c3) + self.assertLess(c3, c1) + + def test_country_hash(self): + """ + Tests if the hash function works + """ + c = location.Country("DE") + + self.assertTrue(hash(c)) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/python/networks-dedup.py b/tests/python/networks-dedup.py new file mode 100755 index 0000000..5b78a4b --- /dev/null +++ b/tests/python/networks-dedup.py @@ -0,0 +1,165 @@ +#!/usr/bin/python3 +############################################################################### +# # +# libloc - A library to determine the location of someone on the Internet # +# # +# Copyright (C) 2024 IPFire Development Team # +# # +# This library is free software; you can redistribute it and/or # +# modify it under the terms of the GNU Lesser General Public # +# License as published by the Free Software Foundation; either # +# version 2.1 of the License, or (at your option) any later version. # +# # +# This library is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# Lesser General Public License for more details. # +# # +############################################################################### + +import location +import os +import tempfile +import unittest + +class Test(unittest.TestCase): + def setUp(self): + # Show even very large diffs + self.maxDiff = None + + def __test(self, inputs, outputs=None): + """ + Takes a list of networks that are written to the database and + compares the result with the second argument. 
+ """ + if outputs is None: + outputs = [network for network, cc, asn in inputs] + + with tempfile.NamedTemporaryFile() as f: + w = location.Writer() + + # Add all inputs + for network, cc, asn in inputs: + n = w.add_network(network) + + # Add CC + if cc: + n.country_code = cc + + # Add ASN + if asn: + n.asn = asn + + # Write file + w.write(f.name) + + # Re-open the database + db = location.Database(f.name) + + # Check if the output matches what we expect + self.assertCountEqual( + outputs, ["%s" % network for network in db.networks], + ) + + def test_dudup_simple(self): + """ + Creates a couple of redundant networks and expects fewer being written + """ + self.__test( + ( + ("10.0.0.0/8", None, None), + ("10.0.0.0/16", None, None), + ("10.0.0.0/24", None, None), + ), + + # Everything should be put into the /8 subnet + ("10.0.0.0/8",), + ) + + def test_dedup_noop(self): + """ + Nothing should be changed here + """ + networks = ( + ("10.0.0.0/8", None, None), + ("20.0.0.0/8", None, None), + ("30.0.0.0/8", None, None), + ("40.0.0.0/8", None, None), + ("50.0.0.0/8", None, None), + ("60.0.0.0/8", None, None), + ("70.0.0.0/8", None, None), + ("80.0.0.0/8", None, None), + ("90.0.0.0/8", None, None), + ) + + # The input should match the output + self.__test(networks) + + def test_dedup_with_properties(self): + """ + A more complicated deduplication test where properties have been set + """ + # Nothing should change here because of different countries + self.__test( + ( + ("10.0.0.0/8", "DE", None), + ("10.0.0.0/16", "AT", None), + ("10.0.0.0/24", "DE", None), + ), + ) + + # Nothing should change here because of different ASNs + self.__test( + ( + ("10.0.0.0/8", None, 1000), + ("10.0.0.0/16", None, 2000), + ("10.0.0.0/24", None, 1000), + ), + ) + + # Everything can be merged again + self.__test( + ( + ("10.0.0.0/8", "DE", 1000), + ("10.0.0.0/16", "DE", 1000), + ("10.0.0.0/24", "DE", 1000), + ), + ("10.0.0.0/8",), + ) + + def test_merge(self): + """ + Checks whether the merging algorithm works + """ + self.__test( + ( + ("10.0.0.0/9", None, None), + ("10.128.0.0/9", None, None), + ), + ("10.0.0.0/8",), + ) + + def test_bug13236(self): + self.__test( + ( + ("209.38.0.0/16", "US", None), + ("209.38.1.0/24", "US", 14061), + ("209.38.160.0/22", "US", 14061), + ("209.38.164.0/22", "US", 14061), + ("209.38.168.0/22", "US", 14061), + ("209.38.172.0/22", "US", 14061), + ("209.38.176.0/20", "US", 14061), + ("209.38.192.0/19", "US", 14061), + ("209.38.224.0/19", "US", 14061), + ), + ( + "209.38.0.0/16", + "209.38.1.0/24", + "209.38.160.0/19", + "209.38.192.0/18", + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/copy.py b/tools/copy.py new file mode 100644 index 0000000..39129c0 --- /dev/null +++ b/tools/copy.py @@ -0,0 +1,108 @@ +#!/usr/bin/python3 +############################################################################### +# # +# libloc - A library to determine the location of someone on the Internet # +# # +# Copyright (C) 2024 IPFire Development Team # +# # +# This library is free software; you can redistribute it and/or # +# modify it under the terms of the GNU Lesser General Public # +# License as published by the Free Software Foundation; either # +# version 2.1 of the License, or (at your option) any later version. # +# # +# This library is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
diff --git a/tools/copy.py b/tools/copy.py
new file mode 100644
index 0000000..39129c0
--- /dev/null
+++ b/tools/copy.py
@@ -0,0 +1,109 @@
+#!/usr/bin/python3
+###############################################################################
+#                                                                             #
+# libloc - A library to determine the location of someone on the Internet    #
+#                                                                             #
+# Copyright (C) 2024 IPFire Development Team                                  #
+#                                                                             #
+# This library is free software; you can redistribute it and/or              #
+# modify it under the terms of the GNU Lesser General Public                 #
+# License as published by the Free Software Foundation; either               #
+# version 2.1 of the License, or (at your option) any later version.         #
+#                                                                             #
+# This library is distributed in the hope that it will be useful,            #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of             #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU           #
+# Lesser General Public License for more details.                            #
+#                                                                             #
+###############################################################################
+
+import argparse
+
+import location
+from location.i18n import _
+
+flags = (
+	location.NETWORK_FLAG_ANONYMOUS_PROXY,
+	location.NETWORK_FLAG_SATELLITE_PROVIDER,
+	location.NETWORK_FLAG_ANYCAST,
+	location.NETWORK_FLAG_DROP,
+)
+
+def copy_all(db, writer):
+	# Copy vendor
+	if db.vendor:
+		writer.vendor = db.vendor
+
+	# Copy description
+	if db.description:
+		writer.description = db.description
+
+	# Copy license
+	if db.license:
+		writer.license = db.license
+
+	# Copy all ASes
+	for old in db.ases:
+		new = writer.add_as(old.number)
+		new.name = old.name
+
+	# Copy all networks
+	for old in db.networks:
+		new = writer.add_network("%s" % old)
+
+		# Copy country code
+		new.country_code = old.country_code
+
+		# Copy ASN
+		if old.asn:
+			new.asn = old.asn
+
+		# Copy flags
+		for flag in flags:
+			if old.has_flag(flag):
+				new.set_flag(flag)
+
+	# Copy countries
+	for old in db.countries:
+		new = writer.add_country(old.code)
+
+		# Copy continent code
+		new.continent_code = old.continent_code
+
+		# Copy name
+		new.name = old.name
+
+def main():
+	"""
+	Main Function
+	"""
+	parser = argparse.ArgumentParser(
+		description=_("Copies a location database"),
+	)
+
+	# Input File
+	parser.add_argument("input-file", help=_("File to read"))
+
+	# Output File
+	parser.add_argument("output-file", help=_("File to write"))
+
+	# Parse arguments
+	args = parser.parse_args()
+
+	input_file = getattr(args, "input-file")
+	output_file = getattr(args, "output-file")
+
+	# Open the database
+	db = location.Database(input_file)
+
+	# Create a new writer
+	writer = location.Writer()
+
+	# Copy everything
+	copy_all(db, writer)
+
+	# Write the new file
+	writer.write(output_file)
+
+if __name__ == "__main__":
+	main()
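copy_all() also doubles as the smallest template for transforming a
database rather than copying it verbatim. A sketch of a filtering
variant, using the same API as the tool above; the DROP-flag condition
and the file names are arbitrary examples, not part of the patch:

    # Sketch: like copy_all(), but keep only networks flagged as DROP.
    import location

    db = location.Database("input.db")      # placeholder path
    writer = location.Writer()

    for old in db.networks:
        if not old.has_flag(location.NETWORK_FLAG_DROP):
            continue

        new = writer.add_network("%s" % old)
        new.set_flag(location.NETWORK_FLAG_DROP)

        # Carry the remaining properties over
        new.country_code = old.country_code
        if old.asn:
            new.asn = old.asn

    writer.write("output.db")               # placeholder path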